In [1]:
import dynet_config
import os
import random
import time

# Configure DyNet through dynet_config. This cell sits before the cell that
# imports dynet, so the settings are in place when the library is loaded.
# NOTE(review): mem is presumably a size in MB and requested_gpus=1 asks for a
# single GPU -- confirm against the installed DyNet version.
dynet_config.set(mem='11000', autobatch=1, requested_gpus=1)

In [2]:
import dynet as dy
import numpy as np

import itertools

from baseline_load_data import load_questions, VOCAB_SIZE
from contextlib import contextmanager

In [3]:
@contextmanager
def parameters(*params):
    """Context manager that hands back ``dy.parameter(p)`` for every argument.

    Usage: ``with parameters(W, b) as (W_expr, b_expr): ...``. The yielded
    value is always a tuple, one entry per argument and in the same order.
    """
    yield tuple(dy.parameter(param) for param in params)

In [4]:
# Load the tokenized training and development questions. Downstream code
# (loss, dev_loss, the training loop) unpacks every item as an (x, y) pair of
# token-id sequences.
train_set = load_questions('./data/train.tok.json')
dev_set = load_questions('./data/dev.tok.json')

In [5]:
# Single collection that owns every trainable parameter created below; it is
# the object saved to / populated from './baseline_bilstm.model'.
model = dy.ParameterCollection()
# Adam optimizer over that collection, constructed with no explicit
# hyper-parameters (library defaults).
trainer = dy.AdamTrainer(model)

In [6]:
# Model hyper-parameters.
NUM_LAYERS = 2        # layer count passed to both RNN builders (fwRNN, bwRNN)
EMBED_SIZE = 256      # width of each row of the token lookup table
HIDDEN_SIZE = 256     # hidden size passed to both RNN builders
ATTENTION_SIZE = 256  # rows of att_W_ctx / att_W_h, i.e. the attention space

In [7]:
# Token embedding table, indexed by token id. The same table is used for the
# source tokens (x) and for the previously emitted target token in loss() and
# sample().
embeds = model.add_lookup_parameters((VOCAB_SIZE, EMBED_SIZE))

In [8]:
# Despite the name, fwRNN is the bidirectional encoder: loss()/sample() call
# fwRNN.transduce() on the embedded source sequence.
fwRNN = dy.BiRNNBuilder(NUM_LAYERS, EMBED_SIZE, HIDDEN_SIZE, model, dy.LSTMBuilder)
# bwRNN is used as the step-by-step decoder. Its input is the attention
# mixture concatenated with the previous token embedding, hence
# HIDDEN_SIZE + EMBED_SIZE.
# NOTE(review): the trailing dy.LSTMBuilder argument looks copied from the
# BiRNNBuilder call above; LSTMBuilder presumably does not take a builder
# factory there, so confirm how the installed DyNet interprets it. Removing it
# may change the parameter layout and invalidate existing checkpoints.
bwRNN = dy.LSTMBuilder(NUM_LAYERS, HIDDEN_SIZE + EMBED_SIZE, HIDDEN_SIZE, model, dy.LSTMBuilder)
# Learned vector fed to the decoder as its very first input (see
# bwRNN.initial_state().add_input(init_input) in loss()/sample()).
bw_init_input = model.add_parameters((HIDDEN_SIZE + EMBED_SIZE))

In [9]:
# Attention parameters.
# att_W_ctx projects the encoder context matrix into the attention space
# (applied once per example in loss()/sample()).
att_W_ctx = model.add_parameters((ATTENTION_SIZE, HIDDEN_SIZE))
# att_W_h projects the current decoder output into the attention space
# (applied at every decoding step in calc_attention()).
att_W_h = model.add_parameters((ATTENTION_SIZE, HIDDEN_SIZE))
# Despite the name, att_b is not added as a bias: calc_attention() multiplies
# this (1, ATTENTION_SIZE) row vector with the tanh activations to obtain one
# score per context column.
att_b = model.add_parameters((1, ATTENTION_SIZE))

In [10]:
def calc_attention(ctx_matrix, ctx_att, h):
    """Attend over the encoder context for the current decoder output.

    Args:
        ctx_matrix: expression whose columns are the encoder outputs.
        ctx_att: ``ctx_matrix`` already projected by ``att_W_ctx`` (callers in
            this file compute it once per example).
        h: current decoder output expression.

    Returns:
        ``(ctx_mixture, att_p)``: ``ctx_matrix * att_p`` and the softmax
        attention weights ``att_p`` themselves.
    """
    with parameters(att_W_h, att_b) as (proj_h, score_row):
        # Add the projected decoder output to every column of the projected
        # context, then squash.
        activations = dy.tanh(dy.colwise_add(ctx_att, proj_h * h))
        # One raw score per context column, transposed before the softmax.
        att_score = dy.transpose(score_row * activations)
        att_p = dy.softmax(att_score)
        ctx_mixture = ctx_matrix * att_p
    return ctx_mixture, att_p

In [11]:
# Output layer: maps a decoder output to one score per vocabulary entry
# (used as W * h + b in loss() and sample()).
out_W = model.add_parameters((VOCAB_SIZE, HIDDEN_SIZE))
out_b = model.add_parameters((VOCAB_SIZE))

In [12]:
def loss(x, y):
    """Build the summed negative log-likelihood of ``y`` given ``x``.

    The source ids ``x`` are embedded and encoded with ``fwRNN``; the decoder
    ``bwRNN`` is then teacher-forced through ``y``: ``y[0]`` is the first
    "previous token" and every element of ``y[1:]`` is a prediction target.

    Args:
        x: sequence of source token ids.
        y: sequence of target token ids (first element seeds the decoder).

    Returns:
        A DyNet expression: the sum of the per-step losses over ``y[1:]``.
    """
    x = [embeds[tid] for tid in x]
    ctx_seq = fwRNN.transduce(x)
    ctx_matrix = dy.concatenate_cols(ctx_seq)
    with parameters(att_W_ctx, bw_init_input) as (W, init_input):
        # The context projection does not depend on the decoder state, so it
        # is computed once here instead of at every decoding step.
        ctx_att = W * ctx_matrix
        current_state = bwRNN.initial_state().add_input(init_input)
    h = current_state.output()
    prev_tid = y[0]
    losses = []
    with parameters(out_W, out_b) as (W, b):
        for next_tid in y[1:]:
            ctx_mixture, _ = calc_attention(ctx_matrix, ctx_att, h)
            current_state = current_state.add_input(dy.concatenate([ctx_mixture, embeds[prev_tid]]))
            h = current_state.output()
            # pickneglogsoftmax applies the (log-)softmax itself, so it must
            # receive the raw scores. Feeding it dy.softmax(...) output, as the
            # previous version did, normalizes twice and trains the wrong
            # objective.
            scores = W * h + b
            losses.append(dy.pickneglogsoftmax(scores, next_tid))
            prev_tid = next_tid
    return dy.esum(losses)

In [13]:
# Resume from a previously saved checkpoint when one exists; otherwise the
# parameters keep their fresh initialization. The training loop writes to this
# same path.
if os.path.exists('./baseline_bilstm.model'):
    model.populate('./baseline_bilstm.model')

In [14]:
def dev_loss(batch_size):
    """Evaluate the model on ``dev_set`` and print its per-token perplexity.

    Args:
        batch_size: number of dev examples evaluated per computation graph.

    Returns:
        float: the sum over batches of the mean per-example loss. The training
        loop compares this value between evaluations for early stopping.
    """
    total_loss = 0.0   # sum of per-batch mean losses (the returned value)
    total_nll = 0.0    # summed loss over every dev example
    token_count = 0    # number of predicted target tokens
    for pos in range(0, len(dev_set), batch_size):
        dy.renew_cg()
        # Slice with the argument, not the global BATCH_SIZE; otherwise
        # batches overlap or skip examples whenever the two differ.
        current_batch = dev_set[pos:pos + batch_size]
        batch_nll = dy.esum([loss(x, y) for x, y in current_batch]).value()
        total_nll += batch_nll
        total_loss += batch_nll / len(current_batch)
        # loss() predicts y[1:], so each example contributes len(y) - 1
        # tokens. Summing len() over the (x, y) tuples would add 2 per example.
        token_count += sum(len(y) - 1 for _, y in current_batch)
    # Perplexity is exp(average loss per predicted token).
    if token_count:
        perplexity = np.exp(total_nll / token_count)
    else:
        perplexity = float('nan')
    print('dev perplexity: %f' % perplexity)
    return total_loss

In [15]:
# Number of examples per computation graph / parameter update in the training
# loop; the loop also passes it to dev_loss().
BATCH_SIZE = 16

In [None]:
# Train until the dev loss stops improving. Every 500 batches (including batch
# 0 of each epoch) the model is evaluated on the dev set; it is checkpointed
# only when the dev loss did not increase since the previous evaluation.
last_loss = None
stop_training = False
for epoch in itertools.count(1):
    print('runing epoch %d...' % epoch)
    random.shuffle(train_set)
    for num_batch, pos in enumerate(range(0, len(train_set), BATCH_SIZE)):
        if num_batch % 500 == 0:
            print(time.ctime())
            total_loss = dev_loss(BATCH_SIZE)
            print('epoch %d batch %d finished' % (epoch, num_batch))
            if last_loss is not None and last_loss < total_loss:
                print('training stoped due to loss increasing on dev.')
                # Leave both loops instead of calling exit(0): exit would
                # terminate the process/kernel and the evaluation cells below
                # could never run in the same session.
                stop_training = True
                break
            model.save('./baseline_bilstm.model')
            last_loss = total_loss
            print(time.ctime())
        dy.renew_cg()
        current_batch = train_set[pos:pos+BATCH_SIZE]
        batch_loss = dy.esum([loss(x, y) for x, y in current_batch]) / len(current_batch)
        batch_loss.backward()
        trainer.update()
    if stop_training:
        break

# The in-memory parameters have moved past the last checkpoint by the time the
# dev loss is seen to increase, so reload the best saved weights. A checkpoint
# always exists here because stopping requires an earlier successful save.
model.populate('./baseline_bilstm.model')

In [16]:
from vocab import load_vocabs, CHECK_A, CHECK_B, CHECK_C, CHECK_D, CHECK_E

In [17]:
# Held-out questions plus the ids of the five answer-choice tokens; sampling
# stops as soon as one of these ids is drawn.
test_set = load_questions('./data/test.tok.json')
tok2id, _ = load_vocabs()
terminal_tids = {
    tok2id[check_token]
    for check_token in (CHECK_A, CHECK_B, CHECK_C, CHECK_D, CHECK_E)
}

In [18]:
def sample(x, y):
    """Sample a decoding for ``x`` and return the answer-choice id it ends on.

    The decoder is seeded with ``y[0]`` and tokens are drawn from the model's
    output distribution for at most ``2 * len(y)`` steps. The first sampled id
    that belongs to ``terminal_tids`` is returned. If none is drawn within the
    budget, one terminal id is drawn according to the renormalized
    probabilities the final step assigned to the terminal ids.

    Args:
        x: sequence of source token ids.
        y: reference target ids; only ``y[0]`` and ``len(y)`` are used.

    Returns:
        One id from ``terminal_tids`` (as produced by ``np.random.choice``).
    """
    x = [embeds[tid] for tid in x]
    ctx_seq = fwRNN.transduce(x)
    ctx_matrix = dy.concatenate_cols(ctx_seq)
    with parameters(att_W_ctx, bw_init_input) as (W, init_input):
        ctx_att = W * ctx_matrix
        current_state = bwRNN.initial_state().add_input(init_input)
    h = current_state.output()
    prev_tid = y[0]
    with parameters(out_W, out_b) as (W, b):
        for _step in range(len(y) * 2):
            ctx_mixture, _ = calc_attention(ctx_matrix, ctx_att, h)
            current_state = current_state.add_input(dy.concatenate([ctx_mixture, embeds[prev_tid]]))
            h = current_state.output()
            probs = dy.softmax(W * h + b).npvalue()[:,0]
            # Renormalize so the vector sums to 1 closely enough for
            # np.random.choice's check.
            probs /= probs.sum()
            next_tid = np.random.choice(VOCAB_SIZE, 1, p=probs)[0]
            if next_tid in terminal_tids:
                return next_tid
            prev_tid = next_tid
    # Step budget exhausted: force a choice among the terminal ids using the
    # last step's distribution.
    options = list(terminal_tids)
    option_probs = np.array([probs[tid] for tid in options], dtype=float)
    option_mass = option_probs.sum()
    if option_mass > 0:
        option_probs /= option_mass
    else:
        # Every terminal probability underflowed to zero; dividing would give
        # NaNs and make np.random.choice raise, so fall back to uniform.
        option_probs = np.full(len(options), 1.0 / len(options))
    return np.random.choice(options, 1, p=option_probs)[0]

In [20]:
# Majority-vote evaluation: for each of the first 100 test questions, draw 100
# samples and count the question as correct when no other answer choice was
# sampled strictly more often than the gold one (ties count as correct).
eval_set = test_set[:100]
correct = 0
for x, y in eval_set:
    counter = {tid: 0 for tid in terminal_tids}
    for _ in range(100):
        dy.renew_cg()
        counter[sample(x, y)] += 1
    gold_tid = y[-1]
    if all(counter[tid] <= counter[gold_tid] for tid in counter if tid != gold_tid):
        correct += 1
# Divide by the number of questions actually evaluated. Dividing by
# len(test_set) understates accuracy whenever the test set has more than 100
# items.
print(float(correct) / len(eval_set) if eval_set else float('nan'))