In [2]:
from __future__ import print_function

import os
import time
import numpy as np
import tensorflow as tf
import pandas as pd
from collections import defaultdict

from sklearn.metrics import roc_auc_score, accuracy_score
import nltk

from correct_text import train, decode, decode_sentence, evaluate_accuracy, create_model,\
    DefaultPTBConfig, DefaultMovieDialogConfig
from text_correcter_data_readers import PTBDataReader, MovieDialogReader
from text_correcter_models import InputBiasedLanguageModel

%matplotlib inline

In [3]:
# Location of the cleaned dialog corpus and of the model checkpoint directory.
# Set the TEXTCORRECTER_DIALOG_CORPUS environment variable to point at your own copy
# instead of editing this cell; without it, this falls back to the original author's
# local path.
root_data_path = os.environ.get(
    "TEXTCORRECTER_DIALOG_CORPUS", "/Users/atpaino/data/textcorrecter/dialog_corpus")
train_path = os.path.join(root_data_path, "cleaned_dialog_train.txt")
val_path = os.path.join(root_data_path, "cleaned_dialog_val.txt")
test_path = os.path.join(root_data_path, "cleaned_dialog_test.txt")
# Checkpoint directory, passed to both train() and create_model() below.
model_path = os.path.join(root_data_path, "dialog_correcter_model")
# Configuration object for the movie-dialog setup; it is passed to MovieDialogReader
# and create_model in the cells below.
config = DefaultMovieDialogConfig()

## Train

In [3]:
# Reader over the training corpus, consumed by the training cell below.
# NOTE: the Decode section rebinds `data_reader` with dropout_prob / replacement_prob /
# dataset_copies set explicitly, so re-run this cell before training again.
data_reader = MovieDialogReader(config, train_path)

In [None]:
# Training is the expensive step of this notebook. Set RUN_TRAINING = False to skip it
# on a Restart & Run All and reuse the checkpoint already stored under `model_path`
# (create_model in the Decode section loads from that directory).
RUN_TRAINING = True
if RUN_TRAINING:
    train(data_reader, train_path, val_path, model_path)

## Decode sentences

In [4]:
# Replace the training-section reader with one configured explicitly for the decoding
# and evaluation cells that follow (ngram_model, decode_sentence, evaluate_accuracy).
# NOTE(review): dropout_prob / replacement_prob presumably control how often tokens are
# dropped / replaced when corrupting sentences -- confirm in
# text_correcter_data_readers.MovieDialogReader.
data_reader = MovieDialogReader(
    config,
    train_path,
    dropout_prob=0.25,
    replacement_prob=0.25,
    dataset_copies=1,
)

In [5]:
# Build the n-gram model that is later passed as `ngram_model=` to decode_sentence.
# It is constructed from the decoding reader and the training corpus file.
ngram_model = InputBiasedLanguageModel(data_reader, train_path)

In [5]:
# Probe the n-gram model on a toy input: "hello" with an empty history.
# NOTE(review): arguments are presumably (token, preceding tokens, input sentence) --
# confirm in text_correcter_models.InputBiasedLanguageModel.prob.
ngram_model.prob("hello", [], ["hello", "friend"])

0.800534625413185

In [6]:
# Same probe for "friend" with an empty history; compare with the next cell, which
# supplies "hello" as history.
ngram_model.prob("friend", [], ["hello", "friend"])

0.3200131397951014

In [7]:
# Probe "friend" with ["hello"] as history. In the saved output this is 0.8, higher
# than the empty-history probe for "friend" (0.32).
ngram_model.prob("friend", ["hello"], ["hello", "friend"])

0.8

In [6]:
# Open a session that every later decode/evaluate cell reuses (it is never closed), and
# load the model from `model_path`; the saved output shows a checkpoint being restored.
# NOTE(review): the positional `True` is presumably create_model's forward-only/decode
# flag -- confirm in correct_text.create_model and prefer passing it by keyword.
sess = tf.InteractiveSession()
model = create_model(sess, True, model_path, config=config)

Reading model parameters from /Users/atpaino/data/textcorrecter/dialog_corpus/dialog_correcter_model/translate.ckpt-15000


In [6]:
# Decode one hand-written ungrammatical sentence with the n-gram model enabled.
# With decode_sentence's default verbosity this prints, per candidate token, an
# "adj prob" and an "orig prob" (the later cells pass verbose=False to suppress it).
# NOTE(review): several logged "adj prob" values exceed 1 (e.g. ~2.0 for "you"), so
# they are unnormalized scores rather than probabilities -- confirm this is intended.
decoded = decode_sentence(sess, model, data_reader, "you have girlfriend", ngram_model=ngram_model)
decoded

adj prob of them is 7.26981852495e-08, orig prob is 7.26981852495e-08
adj prob of a is 1.00000099258, orig prob is 9.92584318737e-07
adj prob of some is 3.25696510117e-06, orig prob is 3.25696510117e-06
adj prob of the is 1.00000054286, orig prob is 5.42858572317e-07
adj prob of we is 2.07565299206e-05, orig prob is 2.07565299206e-05
adj prob of i is 1.17047693493e-05, orig prob is 1.17047693493e-05
adj prob of you is 1.99985921383, orig prob is 0.999859213829
adj prob of they is 0.000102192170743, orig prob is 0.000102192170743
adj prob of those is 1.33790749146e-07, orig prob is 1.33790749146e-07
adj prob of have is 1.00000088826, orig prob is 8.8825555622e-07
Using token you
adj prob of kept is 1.59370938491e-06, orig prob is 1.59370938491e-06
adj prob of be is 1.62224534961e-06, orig prob is 1.62224534961e-06
adj prob of 'll is 1.09729231894, orig prob is 0.0972923189402
adj prob of got is 3.11434405376e-06, orig prob is 3.11434405376e-06
adj prob of some is 2.86392605631e-06, orig

In [7]:
# Sentence with an out-of-vocabulary name ("kvothe"). In the saved output the name is
# kept verbatim and a missing "the" is inserted before "market".
decode_sentence(sess, model, data_reader, "kvothe went to market", ngram_model=ngram_model, verbose=False)

['kvothe', 'went', 'to', 'the', 'market']

In [9]:
# Several nonsense tokens in one sentence. In the saved output they are all kept
# verbatim and "the" is again inserted before "market".
decode_sentence(sess, model, data_reader, "blablahblah and bladdddd went to market", ngram_model=ngram_model, verbose=False)

['blablahblah', 'and', 'bladdddd', 'went', 'to', 'the', 'market']

In [9]:
# Missing-article case: "do you have book".
# NOTE(review): the saved output for this cell is printed "Input:/Output:" text rather
# than the token list the neighbouring cells return, and its execution count (9)
# duplicates another cell's -- it is likely stale; re-run before trusting it.
decode_sentence(sess, model, data_reader, "do you have book", ngram_model=ngram_model, verbose=False)

Input: do you have book
Output: do you have a book



In [8]:
# "then" vs. "than" confusion. NOTE: in the saved output the sentence comes back
# unchanged, i.e. the model does not fix this error.
decode_sentence(sess, model, data_reader, "she did better then him", ngram_model=ngram_model, verbose=False)

['she', 'did', 'better', 'then', 'him']

In [9]:
# Evaluate on the test split (max_samples=1000 presumably caps the number of examples),
# reporting baseline vs. model BLEU and accuracy per bucket.
# NOTE(review): unlike the decode_sentence cells above, ngram_model is not passed here,
# so these scores may not reflect n-gram-assisted decoding -- check whether
# evaluate_accuracy accepts it.
# NOTE(review): in the saved output the model is below the baseline on every bucket for
# both BLEU and accuracy.
evaluate_accuracy(sess, model, data_reader, test_path, max_samples=1000)

Bucket 0: (10, 10)
	Baseline BLEU = 0.8165
	Model BLEU = 0.6651
	Baseline Accuracy: 0.8959
	Model Accuracy: 0.6876
Bucket 1: (15, 15)
	Baseline BLEU = 0.8647
	Model BLEU = 0.7013
	Baseline Accuracy: 0.7561
	Model Accuracy: 0.3780
Bucket 2: (20, 20)
	Baseline BLEU = 0.8951
	Model BLEU = 0.7148
	Baseline Accuracy: 0.7736
	Model Accuracy: 0.3679
Bucket 3: (40, 40)
	Baseline BLEU = 0.9072
	Model BLEU = 0.7216
	Baseline Accuracy: 0.5397
	Model Accuracy: 0.1111
