In [1]:
from __future__ import print_function

import os
import time
import numpy as np
import tensorflow as tf
import pandas as pd
from collections import defaultdict

from sklearn.metrics import roc_auc_score, accuracy_score
import nltk

from correct_text import train, decode, decode_sentence, create_model, DefaultPTBConfig, DefaultMovieDialogConfig
from text_correcter_data_readers import PTBDataReader, MovieDialogReader
from text_correcter_models import InputBiasedNGramModel

%matplotlib inline

In [2]:
# Root directory of the movie-dialog corpus. Set the TEXTCORRECTER_DATA_DIR
# environment variable to point at your own copy; the fallback is the original
# author's local directory, so existing setups keep working unchanged.
root_data_path = os.environ.get(
    "TEXTCORRECTER_DATA_DIR", "/Users/atpaino/data/textcorrecter/dialog_corpus")

# Train / validation / test files, all located under the corpus root.
train_path = os.path.join(root_data_path, "cleaned_dialog_train.txt")
val_path = os.path.join(root_data_path, "cleaned_dialog_val.txt")
test_path = os.path.join(root_data_path, "cleaned_dialog_test.txt")

# Directory the model checkpoints are read from (see the "Reading model
# parameters from ..." output further down).
model_path = os.path.join(root_data_path, "dialog_correcter_model")

# Default configuration for the movie-dialog task, shared by every cell below.
config = DefaultMovieDialogConfig()

## Train

In [3]:
# Reader over the training file, built with the default movie-dialog config.
data_reader = MovieDialogReader(config, train_path)

In [4]:
# Train on train_path and evaluate on val_path; the log below reports training
# perplexity plus per-bucket eval perplexity every 100 global steps.
# NOTE(review): model_path is presumably where checkpoints are written — confirm
# against correct_text.train.
train(data_reader, train_path, val_path, model_path)

Reading data; train = /Users/atpaino/data/textcorrecter/dialog_corpus/cleaned_dialog_train.txt, test = /Users/atpaino/data/textcorrecter/dialog_corpus/cleaned_dialog_val.txt
Creating 2 layers of 512 units.
Created model with fresh parameters.
Training bucket sizes: [226666, 98064, 56724, 80504]
Total train size: 461958.0
global step 100 learning rate 0.5000 step-time 4.61 perplexity 2143.50
  eval: bucket 0 perplexity 220.79
  eval: bucket 1 perplexity 340.07
  eval: bucket 2 perplexity 390.61
  eval: bucket 3 perplexity 604.10
global step 200 learning rate 0.5000 step-time 4.27 perplexity 194.30
  eval: bucket 0 perplexity 69.65
  eval: bucket 1 perplexity 144.68
  eval: bucket 2 perplexity 218.88
  eval: bucket 3 perplexity 315.74
global step 300 learning rate 0.5000 step-time 4.65 perplexity 94.97
  eval: bucket 0 perplexity 30.62
  eval: bucket 1 perplexity 97.79
  eval: bucket 2 perplexity 158.51
  eval: bucket 3 perplexity 225.04
global step 400 learning rate 0.5000 step-time 4.3

## Decode sentences

In [3]:
# Rebuild the reader for decoding. This rebinds the `data_reader` created in the
# Train section, so every later cell uses this instance.
# NOTE(review): dropout_prob / replacement_prob presumably control how source
# sentences are perturbed, and dataset_copies how many times the data is repeated —
# confirm against MovieDialogReader.
data_reader = MovieDialogReader(config, train_path, dropout_prob=0.25, replacement_prob=0.25, dataset_copies=1)

In [4]:
# N-gram model built from the reader and the training file; it is passed to
# decode_sentence as `ngram_model` in the decoding cells below.
ngram_model = InputBiasedNGramModel(data_reader, train_path)

In [6]:
# Sanity check: score "hello" with an empty second argument.
# NOTE(review): the arguments are presumably (token, preceding tokens, input
# sentence tokens) — confirm against InputBiasedNGramModel.prob.
ngram_model.prob("hello", [], ["hello", "friend"])

0.800534625413185

In [9]:
# Sanity check: score "friend" with an empty second argument.
# NOTE(review): the arguments are presumably (token, preceding tokens, input
# sentence tokens) — confirm against InputBiasedNGramModel.prob.
ngram_model.prob("friend", [], ["hello", "friend"])

0.3200131397951014

In [5]:
# Sanity check: score "friend" when the second argument is ["hello"].
# NOTE(review): the arguments are presumably (token, preceding tokens, input
# sentence tokens) — confirm against InputBiasedNGramModel.prob.
ngram_model.prob("friend", ["hello"], ["hello", "friend"])

0.8

In [5]:
# Open an interactive TensorFlow session and load the model from model_path; the
# output below shows which checkpoint was read.
# NOTE(review): the positional True is presumably a forward-only (decode-mode)
# flag — confirm against correct_text.create_model.
sess = tf.InteractiveSession()
model = create_model(sess, True, model_path, config=config)

Reading model parameters from /Users/atpaino/data/textcorrecter/dialog_corpus/dialog_correcter_model/translate.ckpt-20000


In [9]:
# Test a sample from the test dataset.
# The n-gram model is supplied; the trace printed below lists, per candidate
# token, its adjusted score next to its original score, then the token chosen.
decoded = decode_sentence(sess, model, data_reader, "you have girlfriend", ngram_model=ngram_model)

adj prob of she is 19.5401934555, orig prob is 19.5385818481
adj prob of to is 19.8772380668, orig prob is 19.8767967224
adj prob of for is 19.9113501505, orig prob is 19.9108562469
adj prob of of is 19.6186662088, orig prob is 19.617980957
adj prob of 've is 20.2236632214, orig prob is 20.2235832214
adj prob of have is 23.7042996892, orig prob is 23.4905319214
adj prob of you is 45.0243129319, orig prob is 44.2111854553
adj prob of we is 22.7613260914, orig prob is 22.7578716278
adj prob of they is 21.8396258412, orig prob is 21.8379173279
adj prob of i is 25.0736176086, orig prob is 25.0479240417
Using token you
adj prob of shall is 12.1456444127, orig prob is 12.1455154419
adj prob of the is 13.0883536896, orig prob is 13.0876674652
adj prob of used is 12.3276011519, orig prob is 12.327372551
adj prob of must is 13.7987719708, orig prob is 13.7976112366
adj prob of a is 13.773842809, orig prob is 13.7722835541
adj prob of do is 12.6965355754, orig prob is 12.6890964508
adj prob of g

In [12]:
# Probe with a made-up word ("blablahblah") at the start of the input; the trace
# below shows it being selected as the first output token.
decoded = decode_sentence(sess, model, data_reader, "blablahblah went to the market", ngram_model=ngram_model)

adj prob of very is 7.40181912493, orig prob is 7.40153217316
adj prob of even is 7.53177624196, orig prob is 7.53161001205
adj prob of 's is 9.45088200842, orig prob is 9.45079708099
adj prob of one is 8.95518184077, orig prob is 8.95480537415
adj prob of 'll is 7.55581034286, orig prob is 7.55573034286
adj prob of his is 8.7218195339, orig prob is 8.72158432007
adj prob of an is 13.3071738617, orig prob is 13.3069534302
adj prob of the is 17.3503058683, orig prob is 17.2191810608
adj prob of a is 15.3809703451, orig prob is 15.3795566559
adj prob of blablahblah is 24.165336391, orig prob is 23.365316391
Using token blablahblah
adj prob of vision is 10.2649544583, orig prob is 10.2648744583
adj prob of 've is 10.3702906476, orig prob is 10.3702106476
adj prob of aware is 10.7392138348, orig prob is 10.7391338348
adj prob of followed is 12.1530560361, orig prob is 12.1529760361
adj prob of the is 10.9985148697, orig prob is 10.8705148697
adj prob of about is 10.7583865033, orig prob is

In [11]:
# Probe with two made-up words ("blablahblah" and "whatsit") in the input; the
# trace below shows "blablahblah" being selected as the first output token.
decoded = decode_sentence(sess, model, data_reader, "blablahblah went to whatsit", ngram_model=ngram_model)

adj prob of 's is 7.99116807256, orig prob is 7.99108314514
adj prob of me is 8.13354595829, orig prob is 8.13318920135
adj prob of blablahblah is 26.1398732867, orig prob is 25.3398532867
adj prob of 'll is 8.18633164032, orig prob is 8.18625164032
adj prob of even is 10.0410826728, orig prob is 10.0409164429
adj prob of the is 16.1783909647, orig prob is 16.1751861572
adj prob of a is 15.1742242437, orig prob is 15.1728105545
adj prob of an is 13.0255385773, orig prob is 13.0253181458
adj prob of one is 9.22195887935, orig prob is 9.22158241272
adj prob of his is 9.48717178304, orig prob is 9.48693656921
Using token blablahblah
adj prob of an is 10.1539954053, orig prob is 10.1539154053
adj prob of drove is 10.3940276013, orig prob is 10.3939476013
adj prob of a is 10.324884306, orig prob is 10.324804306
adj prob of covered is 10.5649632321, orig prob is 10.5648832321
adj prob of blablahblah is 10.9269086185, orig prob is 10.7668886185
adj prob of the is 10.6831196652, orig prob is 1

In [9]:
# NOTE: decode_str is defined in a later cell of this notebook, so on a fresh
# kernel that cell must be run before this one.
blah = decode_str("do you have book")

Input: do you have book
Output: do you have a book



In [8]:
def decode_str(s):
    """Decode the sentence string `s` with the notebook-level session, model and reader.

    Thin convenience wrapper: forwards `s` to `decode_sentence` together with the
    global `sess`, `model` and `data_reader`, and returns whatever it returns.
    """
    corrected = decode_sentence(sess, model, data_reader, s)
    return corrected

In [9]:

def decode_sentence(sess, model, data_reader, sentence, verbose=True, ngram_model=None):
    """Decode a single sentence; used with InteractiveSession in an IPython notebook.

    This definition shadows the `decode_sentence` imported from `correct_text`
    at the top of the notebook. Other cells call `decode_sentence` with
    `ngram_model=...`, so the keyword is accepted here to keep those calls
    working after this cell has been executed.

    Args:
        sess: the TensorFlow session the model lives in.
        model: the model passed through to `decode`.
        data_reader: the data reader passed through to `decode`.
        sentence: input string; it is split on whitespace into tokens.
        verbose: passed positionally to `decode` as its fifth argument.
        ngram_model: optional n-gram model. Forwarded to `decode` as a keyword
            only when it is not None.

    Returns:
        The first item yielded by `decode` for the one-sentence batch.
    """
    extra_kwargs = {}
    if ngram_model is not None:
        # NOTE(review): assumes decode() accepts an `ngram_model` keyword — TODO confirm.
        extra_kwargs["ngram_model"] = ngram_model
    # decode() is consumed as an iterator; one sentence in, so take the first result.
    return next(decode(sess, model, data_reader, [sentence.split()], verbose, **extra_kwargs))

In [None]:
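# Dropout and replacement rates of 0.9
# NOTE(review): the data_reader cell above sets dropout_prob and replacement_prob
# to 0.25, not 0.9 — confirm which setting produced the results below.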

In [17]:
# Build corpus and decode hypotheses, grouped by the bucket each sample falls in.
baseline_hypotheses = defaultdict(list)  # The model's input
model_hypotheses = defaultdict(list)  # The actual model's predictions
targets = defaultdict(list)  # Reference sentences, each wrapped in a one-element list

for source, target in data_reader.read_samples_by_string(test_path):
    # Use the first bucket whose source-side size is strictly greater than the
    # sample length; samples that fit no bucket are skipped.
    bucket_id = next(
        (i for i, bucket in enumerate(model.buckets) if len(source) < bucket[0]),
        None)
    if bucket_id is None:
        continue

    model_hypotheses[bucket_id].append(next(decode(sess, model, data_reader, [source], verbose=False)))

    # Replace out of vocab words with "UNK" in the baseline hypothesis to make it a little fairer.
    baseline_hypothesis = [word if word in data_reader.token_to_id else MovieDialogReader.UNKNOWN_TOKEN
                           for word in source]
    baseline_hypotheses[bucket_id].append(baseline_hypothesis)

    # nltk's corpus_bleu expects a list of one or more reference translations per sample,
    # so we wrap the target list in another list here.
    targets[bucket_id].append([target])

In [18]:
# Report corpus-level BLEU per bucket: the unmodified input (baseline) versus the
# model's decoded output, both scored against the same references.
corpus_bleu = nltk.translate.bleu_score.corpus_bleu
for bucket_id in targets:
    references = targets[bucket_id]
    baseline_bleu_score = corpus_bleu(references, baseline_hypotheses[bucket_id])
    model_bleu_score = corpus_bleu(references, model_hypotheses[bucket_id])
    print("Bucket {}: {}".format(bucket_id, model.buckets[bucket_id]))
    print("\tBaseline BLEU = {}\n\tModel BLEU = {}".format(baseline_bleu_score, model_bleu_score))

Bucket 0: (10, 10)
	Baseline BLEU = 0.671113372091
	Model BLEU = 0.662576018437
Bucket 1: (15, 15)
	Baseline BLEU = 0.724487231051
	Model BLEU = 0.695375017816
Bucket 2: (20, 20)
	Baseline BLEU = 0.750923521318
	Model BLEU = 0.72081720851
Bucket 3: (40, 40)
	Baseline BLEU = 0.77119960248
	Model BLEU = 0.712245139018
