# Model setup (common to training/evaluation)

In [43]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [35]:
import tensorflow as tf
import numpy as np
import data_utils
import seq2seq_wrapper
import utils
import batcher
import os
import pickle
import time
import nltk

In [3]:
batch_size = 128
emb_dim = 1024
N_EPOCHS = 10
SAVE_EVERY_N_BATCHES = 100
EVAL_EVERY_N_BATCHES = 10
CKPT_DIR = '/var/tmp/archived_checkpoints/30914f8'
LOGS_DIR = 'summaries'
MODEL_NAME = 'baseline'

ckpt_path = os.path.join(CKPT_DIR, 'model')

## Load data

In [4]:
with open('data/movie_script/metadata.pkl', 'rb') as f:
    metadata = pickle.load(f)
trainX = np.load('data/movie_script/training_q.npy')
trainY = np.load('data/movie_script/training_a.npy')
validX = np.load('data/movie_script/validation_q.npy')
validY = np.load('data/movie_script/validation_a.npy')

train_batch_gen = data_utils.rand_batch_gen(trainX, trainY, batch_size)
val_batch_gen = data_utils.rand_batch_gen(validX, validY, 256)

p_r = np.load('data/p_r.npy')

### Sanity check

In [5]:
for b in [train_batch_gen, val_batch_gen]:
    q, a = b.__next__()
    q = data_utils.decode(sequence=list(q[:, 0]), lookup=metadata['idx2w'], separator=' ')
    a = data_utils.decode(sequence=list(a[:, 0]), lookup=metadata['idx2w'], separator=' ')
    print(q)
    print(a)

how you doin ' , pal ?
i ' m okay , sir .
calm down . you brought it up --
i did not , <person> .


## Set up model

In [6]:
# parameters 
xseq_len = trainX.shape[-1]
yseq_len = trainY.shape[-1]
xvocab_size = len(metadata['idx2w'])  
yvocab_size = xvocab_size

In [7]:
xseq_len = trainX.shape[-1]
yseq_len = trainY.shape[-1]
vocab_size = len(metadata['idx2w'])

model = seq2seq_wrapper.Seq2Seq(xseq_len=xseq_len,
                               yseq_len=yseq_len,
                               xvocab_size=xvocab_size,
                               yvocab_size=yvocab_size,
                               ckpt_path='ckpt/',
                               emb_dim=emb_dim,
                               lstm_dim=emb_dim,
                               num_layers=3,
                               attention=False
                               )

<log> Building Graph </log>

In [8]:
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver()

## Set up logging

In [112]:
train_loss_summary = tf.summary.scalar("train_loss", model.loss)
validation_loss_summary = tf.summary.scalar("validation_loss", model.loss)

timestamp = int(time.time())                                                
run_log_dir = os.path.join(LOGS_DIR, MODEL_NAME + '_' + str(timestamp))                        
os.makedirs(run_log_dir)                                                    
# (this step also writes the graph to the events file so that               
# it shows up in TensorBoard)                                               
summary_writer = tf.summary.FileWriter(run_log_dir, sess.graph, flush_secs=5) 

## Load pretrained word embeddings

In [None]:
seq2seq_wrapper.Seq2Seq.load_embeddings(sess, model, metadata, '/scratch/GoogleNews-vectors-negative300.bin')

# Training

## Graph sanity check

In [113]:
x, y = train_batcher.next_batch()
feed_dict = model.get_feed(x, y)
enc_ip, labels = sess.run([model.enc_ip, model.labels], feed_dict)
enc_ip = np.array(enc_ip)
labels = np.array(labels)
# both enc_ip and labels are sequence_len x batch_size
print(data_utils.ids_to_words(enc_ip[:, 0], metadata['idx2w']))
print(data_utils.ids_to_words(labels[:, 0], metadata['idx2w']))

['<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '.', 'really']
['<GO>', 'this', 'is', 'going', 'to', 'be', 'so', 'good', 'for', 'you', '.', '<EOS>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>']


## Run the model

In [115]:
# Initialise these in a separate cell so that we can
# interrupt and resume the training loop
epoch_n = 1
batch_n = 1
train_batcher.reset()
validation_batcher.reset()

In [207]:
feed_dict = model.get_feed(x,y)

In [202]:
try:
    while epoch_n <= N_EPOCHS:
        print("Epoch %d/%d" % (epoch_n, N_EPOCHS))
        while not train_batcher.batches_finished:
            step = (epoch_n - 1) * train_batcher.n_batches + batch_n
            
            start = time.time()
            print("Batch %d/%d" % (batch_n, train_batcher.n_batches))
            x, y = train_batcher.next_batch()
            feed_dict = model.get_feed(x, y)
            ops = [model.train_op, model.loss, train_loss_summary]
            _, loss_v, train_loss_summary_v = sess.run(ops, feed_dict)
            print("Batch loss: %.3f" % loss_v)
            summary_writer.add_summary(train_loss_summary_v, step)
            end = time.time()
            print("Batch took %.2f seconds" % (end - start))
            
            if batch_n % EVAL_EVERY_N_BATCHES == 0:
                val_x, val_y = validation_batcher.next_batch()
                if validation_batcher.batches_finished:
                    validation_batcher.reset()
                feed_dict = model.get_feed(val_x, val_y)
                ops = [model.loss, validation_loss_summary]
                loss_v, validation_loss_summary_v = sess.run(ops, feed_dict)
                print("Validation loss: %.3f" % loss_v)
                summary_writer.add_summary(validation_loss_summary_v, step)
                
            if batch_n % SAVE_EVERY_N_BATCHES == 0:
                print("Saving checkpoint after %d steps..." % step)
                saver.save(sess, ckpt_path, global_step=step)
            batch_n += 1
        train_batcher.reset()
        batch_n = 1
        epoch_n += 1
        
except KeyboardInterrupt:
    print('Interrupted')

Epoch 9/10
Batch 1556/2332
Batch loss: 0.105
Batch took 1.70 seconds
Batch 1557/2332
Batch loss: 0.084
Batch took 1.63 seconds
Batch 1558/2332
Batch loss: 0.102
Batch took 1.55 seconds
Batch 1559/2332
Batch loss: 0.088
Batch took 1.53 seconds
Batch 1560/2332
Batch loss: 0.115
Batch took 1.68 seconds
Validation loss: 0.100
Batch 1561/2332
Batch loss: 0.097
Batch took 1.69 seconds
Batch 1562/2332
Batch loss: 0.117
Batch took 1.56 seconds
Batch 1563/2332
Batch loss: 0.096
Batch took 1.76 seconds
Batch 1564/2332
Batch loss: 0.102
Batch took 1.74 seconds
Batch 1565/2332
Batch loss: 0.109
Batch took 1.54 seconds
Batch 1566/2332
Batch loss: 0.096
Batch took 1.72 seconds
Batch 1567/2332
Batch loss: 0.105
Batch took 1.66 seconds
Batch 1568/2332
Batch loss: 0.072
Batch took 1.66 seconds
Batch 1569/2332
Batch loss: 0.099
Batch took 1.64 seconds
Batch 1570/2332
Batch loss: 0.088
Batch took 1.67 seconds
Validation loss: 0.128
Batch 1571/2332
Batch loss: 0.077
Batch took 1.57 seconds
Batch 1572/2332

# Evaluation

## Load checkpoint

In [9]:
checkpoint_file = tf.train.latest_checkpoint(CKPT_DIR)
print("Restoring checkpoint from %s..." % checkpoint_file)
saver.restore(sess, checkpoint_file)

Restoring checkpoint from /var/tmp/archived_checkpoints/30914f8/baseline.ckpt-50000...


## Get 1000 samples from the validation set

In [11]:
valX = validX[:1000,:]
valY = validY[:1000,:]

## Set p_r (used for diversity-enhanced model)

In [None]:
model.p_r = p_r

## Generate replies

In [13]:
valY_pred = model.predict(sess, valX.T)

In [16]:
replies = []
for ii, oi in zip(valX, valY_pred):
    q = data_utils.decode(sequence=ii, lookup=metadata['idx2w'], separator=' ')
    decoded = data_utils.decode(sequence=oi, lookup=metadata['idx2w'], separator=' ').split(' ')
    if decoded not in replies:
        print('q : [{0}]; a : [{1}]'.format(q, ' '.join(decoded)))
        replies.append(decoded)
print(len(replies))

q : [you wish !]; a : [<person> , i ' m sorry .]
q : [they want me to fly back tonight .]; a : [you ' re not going to be a good time ?]
q : [<person> ? !]; a : [<person> , <person> .]
q : [the mental hospital .]; a : [the same time .]
q : [absolutely not . no . that is not going on now .]; a : [i ' m not going to be a little bit .]
q : [oh my god , you already did it . was it amazing ?]; a : [i don ' t know .]
q : [yeah .]; a : [<person> ' t you know ?]
q : [<unk> . they have their own label that ' s just outstanding .]; a : [i ' m sorry .]
q : [wan na play ?]; a : [no , no .]
q : [<person> shit . <person> was two years ago ? i guess so . she was <unk> hot , though .]; a : [<person> ' s been a little bit .]
q : [what ?]; a : [i ' m going to be a little bit .]
q : [that ' s bullshit .]; a : [no , i ' m not .]
q : [yeah . i just thought this might change things . i hoped . ugh .]; a : [<person> ? what ' s the matter ?]
q : [i was afraid .]; a : [you ' re not going to be a good time .]
q 

In [17]:
with open('baseline_1000_answers.txt', 'w') as file:
    for ii, oi in zip(valX, valY_pred):
        q = data_utils.decode(sequence=ii, lookup=metadata['idx2w'], separator=' ')
        decoded = data_utils.decode(sequence=oi, lookup=metadata['idx2w'], separator=' ').split(' ')
        file.write('q : [{0}]; a : [{1}]'.format(q, ' '.join(decoded)) + '\n')

## Evaluate perplexity

In [23]:
logits = model.predict(sess, valX.T, Y=valY.T, argmax=False)
word_probabilities = utils.softmax(logits, axis=-1)
sentence_perplexities = utils.perplexity(word_probabilities, reference=valY.T)

  if Y == None:
  logp = np.log2(word_probabilities)[reference[i,:] != pad_id]


IndexError: index 20 is out of bounds for axis 0 with size 20

In [30]:
print(word_probabilities.shape)
print(valY.shape)
print(len(sentence_perplexities))

(1000, 20, 10002)
(1000, 20)
1000


In [28]:
sentence_perplexities = utils.perplexity(word_probabilities, reference=valY)

In [31]:
# print a selection of 5
print(sentence_perplexities[0:5])

[48.31062117331394, 835.75975407514022, 13.987229119865045, 32.334449339249893, 109.80054101722652]


In [32]:
with open('baseline_1000_perplexity.txt', 'w') as file:
    for perp in sentence_perplexities:
        file.write(str(perp)+ '\n')

## Evaluate BLEU score

In [38]:
replies = []
for ii, oi in zip(valX, valY_pred):
    q = data_utils.decode(sequence=ii, lookup=metadata['idx2w'], separator=' ')
    decoded = data_utils.decode(sequence=oi, lookup=metadata['idx2w'], separator=' ').split(' ')
    replies.append(decoded)
print(len(replies))

1000


In [41]:
bleu_score = []
for reference, reply in zip(valY, replies):
    reference = data_utils.decode(reference, metadata['idx2w'])
    bleu_score.append(nltk.translate.bleu_score.sentence_bleu(reference, reply))

Corpus/Sentence contains 0 counts of 2-gram overlaps.
BLEU scores might be undesirable; use SmoothingFunction().


In [40]:
print(len(bleu_score))

1000


In [42]:
with open('baseline_1000_bleu.txt', 'w') as file:
    for bleu in bleu_score:
        file.write(str(bleu)+ '\n')

## Evaluate 'max branching' score

In [47]:
max_branch_scores = utils.max_branching_score(valY_pred.tolist())

In [50]:
with open('baseline_1000_branch.txt', 'w') as file:
    for branch in max_branch_scores:
        file.write(str(branch)+ '\n')

## Generate p_r

In [None]:
p_r = np.full(xvocab_size, 1e-12)

### ...using saved answers

In [None]:
sentences = []
sentences_in_words = []
with open('attention_1000_answers.txt', 'r') as f:
    for line in f:
        sentence = []
        sentences_in_words.append(((line.split('[')[-1]).split(']')[0]))
        for word in ((line.split('[')[-1]).split(']')[0]).split(' ')[:-1]:
            try:sentence.append(metadata['w2idx'][word])
            except: pass
        sentences.append(sentence)

In [None]:
for sentence in sentences:
    for word_id in sentence:
        p_r[word_id] += 1.0

### ...using generated answers

In [None]:
for i in range(valY_pred.shape[0]):
    for j in range(valY_pred.shape[1]):
        word_id = valY_pred[i,j]
        if word_id != 0:
            p_r[word_id] += 1.0

### Normalise and save

In [None]:
p_r = p_r / np.sum(p_r)

In [None]:
np.save('p_r.npy', p_r)