In [1]:
import os
import sys

# Pin the process to a single GPU. These are set before TensorFlow is
# imported so the CUDA runtime sees them when it initialises.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import tensorflow as tf

# Deliberately NOT named `config`: that name is bound to the model
# configuration imported from `model_config` further down, and reusing it
# here made the meaning of `config` depend on how far the cell had executed.
session_config = tf.ConfigProto()
session_config.gpu_options.allow_growth = True  # allocate GPU memory on demand
sess = tf.Session(config=session_config)

# Make the parent directory importable (presumably where the project-local
# modules imported below live -- confirm against the repository layout).
if '../' not in sys.path:
    sys.path.append('../')

import pandas as pd

from utils import data_utils
from model_config import config
from ved_varAttnMultiTask import VarSeq2SeqVarAttnMultiTaskModel

Using TensorFlow backend.


In [2]:
# Load the train/val/test CSVs for the experiment named in
# config['experiment'] and derive the source/target sentence series, the
# held-out test columns, the tokenizer filter characters and the word2vec path.
experiment = config['experiment']

# Per experiment: (input column, output column, tokenizer filter characters).
# 'qgen' additionally filters ',', '.' and '?'.
_EXPERIMENT_SPECS = {
    'qgen': ('answer', 'question', '!"#$%&()*+,./:;<=>?@[\\]^`{|}~\t\n'),
    'dialogue': ('line', 'reply', '!"#$%&()*+/:;<=>@[\\]^`{|}~\t\n'),
    'arc': ('ProductSent', 'ProductPhrase', '!"#$%&()*+/:;<=>@[\\]^`{|}~\t\n'),
}

if experiment not in _EXPERIMENT_SPECS:
    # Fail here instead of printing and continuing: every later cell needs the
    # variables defined below and would otherwise die with an unrelated
    # NameError.
    raise ValueError(
        'Invalid experiment name specified! Got {!r}, expected one of {}'.format(
            experiment, sorted(_EXPERIMENT_SPECS)))

print('[INFO] Preparing data for experiment: {}'.format(experiment))
input_column, output_column, filters = _EXPERIMENT_SPECS[experiment]

# Files are named df_<experiment>_<split>.csv under config['data_dir']
# (the directory string is concatenated as-is, so it must end with a separator).
train_data, val_data, test_data = [
    pd.read_csv(config['data_dir'] + 'df_{}_{}.csv'.format(experiment, split))
    for split in ('train', 'val', 'test')
]
all_splits = [train_data, val_data, test_data]

# Tokenisation runs over all three splits concatenated in train/val/test order.
input_sentences = pd.concat([frame[input_column] for frame in all_splits])
output_sentences = pd.concat([frame[output_column] for frame in all_splits])
true_test = test_data[output_column]
input_test = test_data[input_column]
w2v_path = config['w2v_dir'] + 'w2vmodel_{}.pkl'.format(experiment)

if experiment == 'arc':
    # NOTE(review): only 'arc' defines categories / cat_to_index /
    # num_categories / true_cat_test, yet later cells use them
    # unconditionally, so 'qgen' and 'dialogue' will raise NameError there --
    # confirm whether this notebook is meant to support those experiments.
    categories = pd.concat([frame['Category'] for frame in all_splits])

    # Generate labels: map each distinct category to an integer index and back.
    all_categories = categories.unique()
    num_categories = len(all_categories)
    print("total categories: {}".format(num_categories))
    index_to_cat = dict(enumerate(all_categories))
    cat_to_index = {cat: idx for idx, cat in index_to_cat.items()}

    true_cat_test = test_data['Category']

total categories: 948


In [3]:
from train_discriminator import get_label_vec
print('[INFO] Tokenizing input and output sequences')

# Encoder side: tokenise the input sentences; the call yields the token
# sequences together with the word -> index mapping.
x, input_word_index = data_utils.tokenize_sequence(
    input_sentences,
    filters,
    config['encoder_num_tokens'],
    config['encoder_vocab'],
)

# Decoder side: same call on the output sentences with the decoder limits.
y, output_word_index = data_utils.tokenize_sequence(
    output_sentences,
    filters,
    config['decoder_num_tokens'],
    config['decoder_vocab'],
)

# One label vector per category entry, in the same order as `categories`.
z = []
for label_seq in categories:
    z.append(get_label_vec(label_seq, cat_to_index, num_categories))

print('[INFO] Split data into train-validation-test sets')
(x_train, y_train, z_train,
 x_val, y_val, z_val,
 x_test, y_test, z_test) = data_utils.create_data_split_mult(
    x, y, z, config['experiment'])

# Both embedding matrices are built from the same word2vec file.
encoder_embeddings_matrix = data_utils.create_embedding_matrix(
    input_word_index, config['embedding_size'], w2v_path)

decoder_embeddings_matrix = data_utils.create_embedding_matrix(
    output_word_index, config['embedding_size'], w2v_path)

# Re-calculate the vocab size based on the word_idx dictionary.
# NOTE(review): this overwrites the values passed to tokenize_sequence above,
# so re-running this cell tokenises with the updated sizes -- confirm that is
# harmless.
config['encoder_vocab'] = len(input_word_index)
config['decoder_vocab'] = len(output_word_index)

[INFO] Tokenizing input and output sequences
[INFO] Split data into train-validation-test sets


In [4]:
# Instantiate the model from the (vocab-updated) config, both embedding
# matrices and both word -> index mappings.
model = VarSeq2SeqVarAttnMultiTaskModel(
    config,
    encoder_embeddings_matrix,
    decoder_embeddings_matrix,
    input_word_index,
    output_word_index,
)

[INFO] Building Model ...
Instructions for updating:
seq_dim is deprecated, use seq_axis instead
Instructions for updating:
batch_dim is deprecated, use batch_axis instead
Instructions for updating:
keep_dims is deprecated, use keepdims instead


In [5]:
# Resolve the checkpoint file to restore before prediction.
if config['load_checkpoint'] != 0:
    # An explicit checkpoint number was requested:
    # <model_checkpoint_dir><number>.ckpt
    checkpoint = config['model_checkpoint_dir'] + str(config['load_checkpoint']) + '.ckpt'
else:
    # load_checkpoint == 0: fall back to the most recent checkpoint recorded
    # in the 'models' directory.
    checkpoint_dir = os.path.dirname('models/checkpoint')
    checkpoint_state = tf.train.get_checkpoint_state(checkpoint_dir)
    if checkpoint_state is None:
        # get_checkpoint_state returns None when no checkpoint state is
        # available; report that instead of failing on attribute access.
        raise IOError(
            'No checkpoint state found in {!r}; train a model first or set '
            "config['load_checkpoint'] to an existing checkpoint number".format(
                checkpoint_dir))
    checkpoint = checkpoint_state.model_checkpoint_path

In [6]:
# Run prediction on the test split using the weights restored from
# `checkpoint`; the test inputs/targets/labels are passed together with the
# raw reference columns from the test CSV.
preds = model.predict(
    checkpoint,
    x_test,
    y_test,
    z_test,
    true_test,
    true_cat_test,
)

INFO:tensorflow:Restoring parameters from models/var-seq2seq-var-attn-9.ckpt
cnn_loss: 685.4354248046875, cnn_accuracy: 0.009999999776482582
cnn_loss: 685.4354248046875, cnn_accuracy: 0.0
cnn_loss: 685.4354248046875, cnn_accuracy: 0.029999999329447746
cnn_loss: 685.4354248046875, cnn_accuracy: 0.05000000074505806
cnn_loss: 685.4354248046875, cnn_accuracy: 0.009999999776482582
cnn_loss: 685.4354248046875, cnn_accuracy: 0.009999999776482582
cnn_loss: 685.4354248046875, cnn_accuracy: 0.019999999552965164
cnn_loss: 685.4354248046875, cnn_accuracy: 0.029999999329447746
cnn_loss: 685.4354248046875, cnn_accuracy: 0.029999999329447746
cnn_loss: 685.4354248046875, cnn_accuracy: 0.019999999552965164
cnn_loss: 685.4354248046875, cnn_accuracy: 0.019999999552965164
cnn_loss: 685.4354248046875, cnn_accuracy: 0.03999999910593033
cnn_loss: 685.4354248046875, cnn_accuracy: 0.009999999776482582
cnn_loss: 685.4354248046875, cnn_accuracy: 0.019999999552965164
cnn_loss: 685.4354248046875, cnn_accuracy: 0.0

cnn_loss: 685.4354248046875, cnn_accuracy: 0.009999999776482582
cnn_loss: 685.4354248046875, cnn_accuracy: 0.029999999329447746
cnn_loss: 685.4354248046875, cnn_accuracy: 0.019999999552965164
cnn_loss: 685.4354248046875, cnn_accuracy: 0.029999999329447746
cnn_loss: 685.4354248046875, cnn_accuracy: 0.029999999329447746
cnn_loss: 685.4354248046875, cnn_accuracy: 0.009999999776482582
cnn_loss: 685.4354248046875, cnn_accuracy: 0.019999999552965164
cnn_loss: 685.4354248046875, cnn_accuracy: 0.019999999552965164
cnn_loss: 685.4354248046875, cnn_accuracy: 0.019999999552965164
cnn_loss: 685.4354248046875, cnn_accuracy: 0.0
cnn_loss: 685.4354248046875, cnn_accuracy: 0.019999999552965164
cnn_loss: 685.4354248046875, cnn_accuracy: 0.019999999552965164
cnn_loss: 685.4354248046875, cnn_accuracy: 0.009999999776482582
cnn_loss: 685.4354248046875, cnn_accuracy: 0.009999999776482582
cnn_loss: 685.4354248046875, cnn_accuracy: 0.029999999329447746
cnn_loss: 685.4354248046875, cnn_accuracy: 0.019999999552

cnn_loss: 685.4354248046875, cnn_accuracy: 0.03999999910593033
cnn_loss: 685.4354248046875, cnn_accuracy: 0.019999999552965164
cnn_loss: 685.4354248046875, cnn_accuracy: 0.029999999329447746
cnn_loss: 685.4354248046875, cnn_accuracy: 0.029999999329447746
cnn_loss: 685.4354248046875, cnn_accuracy: 0.019999999552965164
cnn_loss: 685.4354248046875, cnn_accuracy: 0.019999999552965164
cnn_loss: 685.4354248046875, cnn_accuracy: 0.009999999776482582
cnn_loss: 685.4354248046875, cnn_accuracy: 0.03999999910593033
cnn_loss: 685.4354248046875, cnn_accuracy: 0.029999999329447746
cnn_loss: 685.4354248046875, cnn_accuracy: 0.019999999552965164
cnn_loss: 685.4354248046875, cnn_accuracy: 0.019999999552965164
cnn_loss: 685.4354248046875, cnn_accuracy: 0.019999999552965164
cnn_loss: 685.4354248046875, cnn_accuracy: 0.009999999776482582
cnn_loss: 685.4354248046875, cnn_accuracy: 0.03999999910593033
cnn_loss: 685.4354248046875, cnn_accuracy: 0.0
cnn_loss: 685.4354248046875, cnn_accuracy: 0.050000000745058