In [97]:
from hope.probabilistic_nlg.utils import calculate_bleu_scores, calculate_ngram_diversity, calculate_entropy, re_tokenize_sequence
from hope.probabilistic_nlg.evaluate_latent_space import get_avg_sent_lengths

In [130]:
from pathlib import Path
import pickle
import numpy as np

In [82]:
# Load the tokenizer and its word->index mapping saved alongside the WAE-DET model.
# NOTE(review): hard-coded absolute path; consider making it configurable.
# NOTE(review): pickle.load can execute arbitrary code -- only load trusted files.
absolute_path = '/mnt/hope/hope/probabilistic_nlg/snli/wae_det'
artefact_dir = Path(absolute_path)

with open(artefact_dir / 'tokenizer.pickle', 'rb') as handle:
    orig_tokenizer = pickle.load(handle)

with open(artefact_dir / 'word_index.pickle', 'rb') as handle:
    word_index = pickle.load(handle)

In [83]:
# Reverse mapping index -> word (for duplicate indices, the last word wins).
inv_word_index = dict(zip(word_index.values(), word_index.keys()))

In [114]:
def tokenizer(sentence):
    """Split ``sentence`` into words using the WAE-DET tokenizer.

    The sentence is re-tokenized with ``re_tokenize_sequence`` (called with 50 and
    30000 -- presumably max sequence length and vocabulary size; confirm in utils).
    The resulting ids are mapped back to words through ``inv_word_index``.

    Returns a list of the words, with the 'EOS', 'PAD' and 'UNK' markers removed.
    """
    markers = ['EOS', 'PAD', 'UNK']
    ids = re_tokenize_sequence([sentence], orig_tokenizer, 50, 30000, word_index)[0]
    words = []
    for idx in ids:
        word = inv_word_index[idx]
        if word not in markers:
            words.append(word)
    return words

In [85]:
# Text files to evaluate, keyed by short name. 'ref' and 'par' are used as the
# comparison side in the BLEU cells; BLEU pairs lines by position, so files are
# presumably aligned line-by-line (confirm).
file_paths = {
    "ref": './output/ref.txt',
    "par": './output/par.txt',
    "cgmh": './output/CGMH/output.txt100',
    "vae": './output/VAE/output.txt100',
    "wae_det": './output/WAE-DET/output.txt100',
    "wae_st": './output/WAE-ST/output.txt100',
}

# Read each file once, lower-casing and stripping every line so all metrics see the
# same normalised text. `with` closes each handle; explicit UTF-8 avoids depending
# on the platform's default encoding.
files = {}
for filename, path in file_paths.items():
    with open(path, encoding='utf-8') as handle:
        files[filename] = [line.lower().strip() for line in handle]

In [86]:
# Length of the shortest file; the BLEU cells truncate every file to this many lines
# so that all of them are scored on the same prefix.
max_lines = min(len(lines) for lines in files.values())

In [87]:
# Display the common line count used to truncate files for BLEU.
max_lines

1075

# BLEU

In [125]:
# BLEU of every file against 'ref', restricted to the first `max_lines` lines.
# The printed tuple has 4 scores (presumably BLEU-1..4; confirm in utils).
print('BLEU_ref:')
for system, hypotheses in files.items():
    # References are passed as one single-element list per hypothesis.
    wrapped_refs = [[ref_line] for ref_line in files['ref'][:max_lines]]
    bleu = calculate_bleu_scores(wrapped_refs, hypotheses[:max_lines])
    print(system, bleu)

BLEU_ref:
ref (100.0, 100.0, 100.0, 100.0)
par (80.32, 73.86, 69.01, 65.0)
cgmh (83.27, 77.91, 74.48, 71.73)
vae (82.08, 76.27, 72.65, 69.85)
wae_det (83.58, 78.09, 74.63, 71.89)
wae_st (83.13, 77.8, 74.41, 71.72)


In [124]:
# BLEU of every file against 'par' (presumably the original input sentences, given
# the 'BLEU_orig' label -- confirm), restricted to the first `max_lines` lines.
print('BLEU_orig:')
for system, hypotheses in files.items():
    # References are passed as one single-element list per hypothesis.
    wrapped_refs = [[ref_line] for ref_line in files['par'][:max_lines]]
    bleu = calculate_bleu_scores(wrapped_refs, hypotheses[:max_lines])
    print(system, bleu)

BLEU_orig:
ref (80.31, 73.86, 69.02, 65.02)
par (100.0, 100.0, 100.0, 100.0)
cgmh (74.11, 64.19, 57.49, 52.42)
vae (74.66, 64.1, 57.01, 51.73)
wae_det (74.13, 64.24, 57.56, 52.53)
wae_st (74.0, 64.05, 57.39, 52.36)


# Entropy

In [131]:
# Per file: mean sentence length (in word tokens) and token entropy of the whole file.
# NOTE(review): unlike the BLEU cells, this uses every line of each file rather than
# the first `max_lines`, so files are measured on different numbers of sentences.
print('ENTROPY:')
for name, lines in files.items():
    # One list of word tokens per sentence.
    sentence_tokens = [tokenizer(line) for line in lines]
    # Flattened, order-preserving token stream for the entropy estimate.
    tokens = [tok for sent in sentence_tokens for tok in sent]
    avglen = np.mean([len(sent) for sent in sentence_tokens])

    print(f'{name}: Average Length: {avglen}', f'Entropy: {calculate_entropy(tokens)}')

ENTROPY:
ref: Average Length: 8.188995263394144 Entropy: 5.805686176695064
par: Average Length: 8.200726234900813 Entropy: 5.803647001339448
cgmh: Average Length: 9.444989488437281 Entropy: 5.88056232197082
vae: Average Length: 10.03813953488372 Entropy: 5.728229526317101
wae_det: Average Length: 9.412612612612612 Entropy: 5.824152983276227
wae_st: Average Length: 9.557663125948407 Entropy: 5.731383411366381


# Perplexity

In [136]:
import math
from pytorch_pretrained_bert import OpenAIGPTTokenizer, OpenAIGPTModel, OpenAIGPTLMHeadModel

In [138]:
# Load pre-trained model (weights)
# OpenAI GPT with its language-modelling head; used below (get_score) to score
# sentences through the LM loss.
model = OpenAIGPTLMHeadModel.from_pretrained('openai-gpt')

[tqdm download progress for the 'openai-gpt' weights (478,750,579 bytes) elided; the saved output was a truncated stream of progress-bar frames with terminal escape codes]

In [143]:
# Switch to inference mode: disables dropout so perplexity scores are deterministic.
# The trailing ';' suppresses the (large) module repr that eval() returns.
model.eval();

In [190]:
import torch
from tqdm import tqdm

In [155]:
# Load pre-trained model tokenizer (vocabulary)
tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt')



In [180]:
def get_score(sentence):
    tokenize_input = tokenizer.tokenize(sentence)
    tensor_input = torch.tensor([tokenizer.convert_tokens_to_ids(tokenize_input)])
    loss=model(tensor_input, lm_labels=tensor_input)
    return math.exp(loss)

In [None]:
# Per-sentence GPT perplexity for the 'vae' outputs (the only system scored here).
sentences = files['vae']
scores = [get_score(sent) for sent in tqdm(sentences)]

In [None]:
# Mean per-sentence perplexity. Sentences too short to score (a single GPT token) can
# produce NaN; a plain mean would then be NaN, so NaN entries are excluded here.
np.nanmean(scores)