In [None]:
from __future__ import print_function

import os
import time
import numpy as np
import tensorflow as tf
import pandas as pd
from collections import defaultdict

from sklearn.metrics import roc_auc_score, accuracy_score
import nltk

from correct_text import train, decode, create_model, DefaultPTBConfig, DefaultMovieDialogConfig
from text_correcter_data_readers import PTBDataReader, MovieDialogReader

%matplotlib inline

In [None]:
# Root directory of the dialog corpus. Set the TEXT_CORRECTER_DATA_PATH
# environment variable to run this notebook on another machine; when it is
# unset, the original author's local path is used as the fallback.
root_data_path = os.environ.get(
    "TEXT_CORRECTER_DATA_PATH",
    "/Users/atpaino/data/textcorrecter/dialog_corpus")

# Train / validation / test files inside the corpus directory.
# (The PTB counterparts of the first two were "ptb.train.txt" and "ptb.valid.txt".)
train_path = os.path.join(root_data_path, "cleaned_dialog_train.txt")
val_path = os.path.join(root_data_path, "cleaned_dialog_val.txt")
test_path = os.path.join(root_data_path, "cleaned_dialog_test.txt")

# Model location passed to train() and create_model() in the cells below.
model_path = os.path.join(root_data_path, "dialog_correcter_model")

# Default configuration object for the movie-dialog task; shared by the data
# readers and create_model().
config = DefaultMovieDialogConfig()

## Train

In [None]:
# Build the movie-dialog reader from the default config and the training file.
# This instance is the one handed to train() in the next cell.
data_reader = MovieDialogReader(config, train_path)

In [None]:
# Run training with the reader built above, the training and validation files,
# and the model location.
# NOTE(review): presumably long-running and writes checkpoints under
# model_path -- confirm in correct_text.train.
train(data_reader, train_path, val_path, model_path)

## Decode sentences

In [None]:
# Open an interactive TensorFlow session and build the model inside it.
# NOTE(review): the positional `True` is presumably a forward-only / inference
# flag, and create_model presumably restores saved parameters from model_path
# -- confirm against correct_text.create_model.
sess = tf.InteractiveSession()
model = create_model(sess, True, model_path, config=config)

In [None]:
# Rebind data_reader (replacing the reader built for training) to one created
# with dropout and replacement probabilities of 0.9 and a single dataset copy.
# dec() and the evaluation loop further down read this global.
# NOTE(review): presumably these probabilities control how often tokens are
# dropped / replaced when generating noisy inputs -- confirm in MovieDialogReader.
data_reader = MovieDialogReader(config, train_path, dropout_prob=0.9, replacement_prob=0.9, dataset_copies=1)

In [None]:
def dec(s, verbose=True):
    """Decode the single sentence ``s`` with the notebook's global model.

    Convenience wrapper: ``s`` is placed in a one-element list and forwarded
    to ``decode`` together with the global ``sess``, ``model`` and
    ``data_reader``.

    Args:
        s: The sentence to decode.
        verbose: Forwarded to ``decode`` unchanged.

    Returns:
        Whatever ``decode`` returns for the one-sentence batch.
    """
    single_sentence_batch = [s]
    return decode(sess, model, data_reader, single_sentence_batch, verbose=verbose)

In [None]:
# Test a sample from the test dataset.
decoded = decode(sess, model, data_reader,
                 ["you have girlfriend"])

In [None]:
# The evaluation below uses a data reader with dropout and replacement probabilities of 0.9.

In [None]:
# Build the evaluation corpus from the test file and decode hypotheses.
# Everything is grouped by model bucket id so BLEU can be reported per bucket;
# the three dicts hold parallel lists (same sample order within a bucket).
baseline_hypotheses = defaultdict(list)  # The model's input
model_hypotheses = defaultdict(list)  # The actual model's predictions
targets = defaultdict(list)

for source, target in data_reader.read_token_samples(test_path):

    # Pick the first bucket whose source-side size is strictly greater than the
    # sample length; samples that fit no bucket are skipped.
    bucket_id = next((i for i, bucket in enumerate(model.buckets)
                      if len(source) < bucket[0]), None)
    if bucket_id is None:
        continue

    # dec() decodes a one-sentence batch, so [0] is the hypothesis for this sample.
    model_hypotheses[bucket_id].append(dec(" ".join(source), verbose=False)[0])

    # Replace out of vocab words with "UNK" in the baseline hypothesis to make it a little fairer.
    baseline_hypothesis = [word if word in data_reader.token_to_id else MovieDialogReader.UNKNOWN_TOKEN
                           for word in source]
    baseline_hypotheses[bucket_id].append(baseline_hypothesis)

    # nltk.translate.bleu_score.corpus_bleu expects a list of one or more
    # reference translations per sample, so we wrap the target list in another
    # list here.
    targets[bucket_id].append([target])

In [None]:
# Report corpus-level BLEU per bucket for the baseline (the model's input with
# out-of-vocabulary words masked) and for the model's predictions.
# Iterating in sorted order makes the report run from the smallest bucket id
# to the largest instead of following dict ordering.
for bucket_id in sorted(targets):
    baseline_bleu_score = nltk.translate.bleu_score.corpus_bleu(targets[bucket_id], baseline_hypotheses[bucket_id])
    model_bleu_score = nltk.translate.bleu_score.corpus_bleu(targets[bucket_id], model_hypotheses[bucket_id])
    print("Bucket {}: {}".format(bucket_id, model.buckets[bucket_id]))
    print("\tBaseline BLEU = {}\n\tModel BLEU = {}".format(baseline_bleu_score, model_bleu_score))