In [1]:
from __future__ import print_function

import os
import time
import numpy as np
import tensorflow as tf
import pandas as pd
from collections import defaultdict

from sklearn.metrics import roc_auc_score, accuracy_score
import nltk

from correct_text import train, decode, decode_sentence, evaluate_accuracy, create_model,\
    DefaultPTBConfig, DefaultMovieDialogConfig
from text_correcter_data_readers import PTBDataReader, MovieDialogReader
from text_correcter_models import InputBiasedLanguageModel

%matplotlib inline

In [2]:
# Paths to the cleaned dialog-corpus splits and to the model checkpoint directory.
# The corpus root can be overridden with the TEXTCORRECTER_DATA_DIR environment
# variable so the notebook runs on machines other than the author's; when the
# variable is unset, the original hard-coded location is used unchanged.
root_data_path = os.environ.get(
    "TEXTCORRECTER_DATA_DIR", "/Users/atpaino/data/textcorrecter/dialog_corpus")
train_path = os.path.join(root_data_path, "cleaned_dialog_train.txt")
val_path = os.path.join(root_data_path, "cleaned_dialog_val.txt")
test_path = os.path.join(root_data_path, "cleaned_dialog_test.txt")
# Directory from which the model cells read their checkpoints (translate.ckpt-*).
model_path = os.path.join(root_data_path, "dialog_correcter_model")
config = DefaultMovieDialogConfig()

## Train

In [3]:
# Reader over the training split, consumed by the training cell below. It is
# built with MovieDialogReader's default settings: no explicit dropout_prob /
# replacement_prob / dataset_copies, unlike the reader rebuilt in the decode section.
data_reader = MovieDialogReader(config, train_path)

In [4]:
# Train the correction model on train_path, evaluating per bucket on val_path
# and reading checkpoints from model_path. The recorded run resumed from
# translate.ckpt-15000 and was stopped manually (KeyboardInterrupt).
# NOTE(review): this presumably keeps training until interrupted, so
# Restart & Run All will block here; consider guarding this cell.
train(data_reader, train_path, val_path, model_path)

Reading data; train = /Users/atpaino/data/textcorrecter/dialog_corpus/cleaned_dialog_train.txt, test = /Users/atpaino/data/textcorrecter/dialog_corpus/cleaned_dialog_val.txt
Creating 2 layers of 512 units.
Reading model parameters from /Users/atpaino/data/textcorrecter/dialog_corpus/dialog_correcter_model/translate.ckpt-15000
Training bucket sizes: [226666, 98064, 56724, 80504]
Total train size: 461958.0
global step 15100 learning rate 0.4049 step-time 4.43 perplexity 1.05
  eval: bucket 0 perplexity 1.04
  eval: bucket 1 perplexity 1.04
  eval: bucket 2 perplexity 1.12
  eval: bucket 3 perplexity 1.67
global step 15200 learning rate 0.4049 step-time 5.28 perplexity 1.12
  eval: bucket 0 perplexity 1.02
  eval: bucket 1 perplexity 1.13
  eval: bucket 2 perplexity 1.15
  eval: bucket 3 perplexity 1.16
global step 15300 learning rate 0.4049 step-time 4.65 perplexity 1.06
  eval: bucket 0 perplexity 1.07
  eval: bucket 1 perplexity 1.05
  eval: bucket 2 perplexity 1.05
  eval: bucket 3 pe

KeyboardInterrupt: 

## Decode sentences

In [3]:
# Rebuild the reader for decoding/evaluation with explicit settings.
# NOTE(review): dropout_prob / replacement_prob are presumably the rates used to
# synthesize noisy input sentences; confirm in text_correcter_data_readers.
reader_kwargs = dict(dropout_prob=0.25, replacement_prob=0.25, dataset_copies=1)
data_reader = MovieDialogReader(config, train_path, **reader_kwargs)

In [4]:
# Build an InputBiasedLanguageModel from the reader and the training split.
# It is passed as ngram_model to decode_sentence / evaluate_accuracy below.
# The "adj prob ... orig prob" trace printed by decode_sentence suggests it
# re-weights the decoder's token probabilities (confirm in correct_text).
ngram_model = InputBiasedLanguageModel(data_reader, train_path)

In [5]:
# Sanity check: probability of "hello" with an empty context, given the input
# sentence ["hello", "friend"].
# NOTE(review): argument roles (token, preceding context, input tokens) are
# inferred from usage; confirm against InputBiasedLanguageModel.prob.
ngram_model.prob("hello", [], ["hello", "friend"])

0.800534625413185

In [6]:
# Same sanity check for "friend" with an empty context. The recorded value is
# much lower than the one for "hello".
ngram_model.prob("friend", [], ["hello", "friend"])

0.3200131397951014

In [7]:
# "friend" following the context ["hello"]; compare with the empty-context
# value for "friend" in the previous cell.
ngram_model.prob("friend", ["hello"], ["hello", "friend"])

0.8

In [5]:
# Open a TF1-style interactive session and load the model's parameters from a
# checkpoint under model_path (the output shows which checkpoint was read).
# NOTE(review): the positional True is presumably a forward-only/inference flag;
# confirm against create_model and pass it by keyword. The session is not closed
# anywhere in this notebook as shown.
sess = tf.InteractiveSession()
model = create_model(sess, True, model_path, config=config)

Reading model parameters from /Users/atpaino/data/textcorrecter/dialog_corpus/dialog_correcter_model/translate.ckpt-31300


In [6]:
# Decode a sample that is missing an article ("you have girlfriend"). verbose is
# left at its default here, which prints, at each step, the adjusted and original
# probability of every candidate token and the token chosen.
# NOTE(review): an earlier comment said this sample is from the test dataset;
# that is not verifiable from this notebook.
decoded = decode_sentence(sess, model, data_reader, "you have girlfriend", ngram_model=ngram_model)

adj prob of the is 0.0032125145679, orig prob is 7.7070453699e-06
adj prob of have is 0.213768005877, orig prob is 2.38067727309e-07
adj prob of you is 1.81305201706, orig prob is 0.99992454052
adj prob of a is 0.00141589109651, orig prob is 2.20189372158e-06
adj prob of girlfriend is 0.213336618282, orig prob is 6.69496704982e-15
adj prob of an is 0.00022046399046, orig prob is 3.24303144339e-08
Using token you
adj prob of 've is 0.0039547956255, orig prob is 2.80572476186e-05
adj prob of have is 1.78443975162, orig prob is 0.979787528515
adj prob of girlfriend is 0.213333333333, orig prob is 2.83150137722e-14
adj prob of a is 0.00155948821844, orig prob is 2.33249593862e-07
adj prob of 's is 8.33846878109e-05, orig prob is 7.39336147859e-09
adj prob of 'll is 0.0233438322684, orig prob is 0.0201786737889
Using token have
adj prob of the is 0.0273468691474, orig prob is 0.0202672537416
adj prob of girlfriend is 0.807891923189, orig prob is 0.00789192318916
adj prob of a is 0.989217468

In [7]:
# Decode a longer sentence with the default (verbose) tracing. The same sentence
# appears among the evaluate_accuracy errors at the end of the notebook, where
# the decoded output contains "new new thing".
decoded = decode_sentence(sess, model, data_reader,
                          "did n't you say that they 're going to develop this revolutionary new thing ...",
                          ngram_model=ngram_model)

adj prob of say is 0.0428358415407, orig prob is 1.21358530863e-11
adj prob of they is 0.0442951800221, orig prob is 7.74678444223e-14
adj prob of develop is 0.042666694946, orig prob is 2.82793788386e-08
adj prob of did is 1.8010124398, orig prob is 0.999998211861
adj prob of going is 0.0426937674941, orig prob is 5.47143812191e-14
adj prob of n't is 0.0426666666701, orig prob is 3.4008301817e-12
adj prob of the is 0.00320486693163, orig prob is 5.94090927564e-08
adj prob of ... is 0.044791211328, orig prob is 4.04085032102e-09
adj prob of thing is 0.042674067052, orig prob is 9.25055143597e-09
adj prob of to is 0.0430280110321, orig prob is 1.41684895153e-13
adj prob of new is 0.0427028011032, orig prob is 1.59016565682e-16
adj prob of you is 0.0557941432159, orig prob is 5.78321938949e-12
adj prob of 're is 0.0426666666667, orig prob is 1.62911550737e-17
adj prob of that is 0.0466816953073, orig prob is 9.34927862125e-16
adj prob of this is 0.0442155200207, orig prob is 6.4474541157

In [7]:
# Unusual proper noun: in the recorded output the model keeps "kvothe" and
# inserts "the" before "market".
decode_sentence(sess, model, data_reader, "kvothe went to market", ngram_model=ngram_model, verbose=False)

['kvothe', 'went', 'to', 'the', 'market']

In [9]:
# Nonsense tokens: the recorded output keeps them but duplicates the final word
# ("market market") instead of inserting an article.
decode_sentence(sess, model, data_reader, "blablahblah and bladdddd went to market", ngram_model=ngram_model, verbose=False)

['blablahblah', 'and', 'bladdddd', 'went', 'to', 'market', 'market']

In [9]:
# Missing article before "book".
# NOTE(review): this cell shares execution count In [9] with the previous cell.
# Its "Input:/Output:" output also differs in form from the token lists shown by
# the other verbose=False calls, so the stored output is likely stale.
decode_sentence(sess, model, data_reader, "do you have book", ngram_model=ngram_model, verbose=False)

Input: do you have book
Output: do you have a book



In [8]:
# Commonly confused word ("then" for "than"): in the recorded output the
# sentence is returned unchanged.
decode_sentence(sess, model, data_reader, "she did better then him", ngram_model=ngram_model, verbose=False)

['she', 'did', 'better', 'then', 'him']

In [6]:
# Evaluate on test_path. max_samples=1000 is presumably a cap on the number of
# evaluated examples. Prints per-bucket BLEU and accuracy for the baseline and
# the model; the returned errors are iterated in the next cell as
# (decoding, target) token-sequence pairs.
# In the recorded run the model beats the baseline on accuracy in buckets 0, 1
# and 3 but not in bucket 2, and its BLEU is lower than the baseline's in
# buckets 1 and 2.
errors = evaluate_accuracy(sess, model, data_reader, ngram_model, test_path, max_samples=1000)

Bucket 0: (10, 10)
	Baseline BLEU = 0.8136
	Model BLEU = 0.8238
	Baseline Accuracy: 0.8891
	Model Accuracy: 0.9238
Bucket 1: (15, 15)
	Baseline BLEU = 0.8855
	Model BLEU = 0.8797
	Baseline Accuracy: 0.7927
	Model Accuracy: 0.8598
Bucket 2: (20, 20)
	Baseline BLEU = 0.9057
	Model BLEU = 0.8814
	Baseline Accuracy: 0.8091
	Model Accuracy: 0.8000
Bucket 3: (40, 40)
	Baseline BLEU = 0.9018
	Model BLEU = 0.9030
	Baseline Accuracy: 0.6423
	Model Accuracy: 0.6667


In [13]:
# Show each evaluation mistake as a decoded/target pair, followed by a blank line.
# In the recorded output every mistake shown repeats a word (e.g. "second second"
# where the target has "a second").
for decoded_tokens, target_tokens in errors:
    decoded_text = " ".join(decoded_tokens)
    target_text = " ".join(target_tokens)
    print("Decoding: " + decoded_text)
    print("Target:   " + target_text + "\n")

Decoding: you 'll beg for mercy in second second .
Target:   you 'll beg for mercy in a second .

Decoding: i 'm dying for a shower . you could use one one too . and we 'd better check that bandage .
Target:   i 'm dying for a shower . you could use one too . and we 'd better check that bandage .

Decoding: listen . understand . i 'm not a military objective , reese . i 'm person person ... you do n't own me .
Target:   listen . understand . i 'm not a military objective , reese . i 'm a person ... you do n't own me .

Decoding: whatever ... they become the hotshot computer guys so they get the job to build el computer grande ... skynet ... for government government . right ?
Target:   whatever ... they become the hotshot computer guys so they get the job to build el computer grande ... skynet ... for the government . right ?

Decoding: did n't you say that they 're going to develop this revolutionary new new thing ...
Target:   did n't you say that they 're going to develop this revol