In [1]:
import itertools
import json
import math
import os
import random
import shutil
from functools import reduce
from operator import mul
from pprint import pprint

import nltk
import numpy as np
import tensorflow as tf
from tqdm import tqdm

import flag as fg
from utils.helper import get_initializer, dropout, conv1d, multi_conv1d
from utils.helper import flatten, reconstruct, linear, highway_layer, highway_network, mask, exp_mask, softmax, grouper
from utils.utils import index
from utils.read_data import DataSet, read_data, get_squad_data_filter, update_config
from utils.evaluation_utils import compare1, compare2, _get1, _get2, func_span_f1, get_best_span

In [8]:
# Build the run configuration.
# NOTE(review): `_` is IPython's last-output variable, so the argument passed to
# fg.main depends on which cell ran last. Confirm what fg.main expects and pass it
# explicitly.
config = fg.main(_)
config.model_name = 'model_60'
config.out_dir = os.path.join(config.out_base_dir, config.model_name, str(config.run_id).zfill(2))

assert config.load or config.mode == 'train', "config.load must be True if not training"
# A fresh (non-resumed) run starts from an empty output directory.
if not config.load and os.path.exists(config.out_dir):
    shutil.rmtree(config.out_dir)

config.save_dir = os.path.join(config.out_dir, "save")
config.log_dir = os.path.join(config.out_dir, "log")
config.eval_dir = os.path.join(config.out_dir, "eval")
config.answer_dir = os.path.join(config.out_dir, "answer")
# Create the run directory and its sub-directories; existing ones are left alone.
for run_dir in (config.out_dir, config.save_dir, config.log_dir, config.answer_dir, config.eval_dir):
    os.makedirs(run_dir, exist_ok=True)

data_filter = get_squad_data_filter(config)

# POS tags (including the empty tag) that are collapsed into the catch-all 'OTHER' tag.
RARE_POS_TAGS = frozenset(['', '$', 'PDT', 'WP$', 'SYM', 'LS', '#', 'UH'])


def _collapse_rare_tags(tags):
    """Replace, in place, every entry of the list `tags` found in RARE_POS_TAGS with 'OTHER'."""
    for k, tag in enumerate(tags):
        if tag in RARE_POS_TAGS:
            tags[k] = 'OTHER'


def filter_pos_x(pos_data):
    """Collapse RARE_POS_TAGS to 'OTHER' in a three-level nested list of POS tags.

    `pos_data` is indexed as pos_data[i][j][k] -> tag string (presumably
    article -> paragraph -> token; confirm against the pos_x_*.json producer).
    The lists are modified in place and the same object is returned.
    """
    for para in pos_data:
        for tags in para:
            _collapse_rare_tags(tags)
    return pos_data


def filter_pos_q(pos_q):
    """Collapse RARE_POS_TAGS to 'OTHER' in a two-level nested list (question -> token tags).

    The lists are modified in place and the same object is returned.
    """
    for tags in pos_q:
        _collapse_rare_tags(tags)
    return pos_q

train_data = read_data(config, 'train', False, data_filter=data_filter)
dev_data = read_data(config, 'dev', False, data_filter=data_filter)

# Directory holding the precomputed POS-tag files.
SQUAD_DIR = 'data/squad'


def load_squad_json(file_name):
    """Load and return the JSON document stored at SQUAD_DIR/file_name."""
    with open(os.path.join(SQUAD_DIR, file_name), 'r') as fh:
        return json.load(fh)


# Context ('pos_x') and question ('pos_q') POS tags, passed through filter_pos_x / filter_pos_q.
pos_x_train = filter_pos_x(load_squad_json('pos_x_train.json'))
pos_x_dev = filter_pos_x(load_squad_json('pos_x_dev.json'))
pos_q_train = filter_pos_q(load_squad_json('pos_q_train.json'))
pos_q_dev = filter_pos_q(load_squad_json('pos_q_dev.json'))

pos2int = load_squad_json('pos_vocab.json')  # POS tag -> row index of pos_embd
pos_embd = load_squad_json('pos_embd.json')  # POS embedding table, one list per tag

train_data.shared['pos_x'] = pos_x_train
dev_data.shared['pos_x'] = pos_x_dev
train_data.shared['pos_q'] = pos_q_train
dev_data.shared['pos_q'] = pos_q_dev

update_config(config, [train_data, dev_data])

word2vec_dict = train_data.shared['lower_word2vec'] if config.lower_word else train_data.shared['word2vec']
word2idx_dict = train_data.shared['word2idx']

# Initial word-embedding matrix: the pretrained vector for every indexed word that has one,
# a standard-normal random vector for every other index.
idx2vec_dict = {word2idx_dict[word]: vec for word, vec in word2vec_dict.items() if word in word2idx_dict}
emb_mat = np.array([idx2vec_dict[idx] if idx in idx2vec_dict
                    else np.random.multivariate_normal(np.zeros(config.word_emb_size), np.eye(config.word_emb_size))
                    for idx in range(config.word_vocab_size)])

config.emb_mat = emb_mat
config.pos_emb_mat = pos_embd
config.pos_vocab_size = len(pos2int)
config.pos_emb_size = len(pos_embd[0])
config.pos2int = pos2int

In [5]:
config.batch_size = 60

# Short dimension names used by the graph-building code.
# Context tensors are [N, M, JX] ([N, M, JX, W] for chars); questions are [N, JQ] ([N, JQ, W] for chars).
N, M, JX, JQ = config.batch_size, config.max_num_sents, config.max_sent_size, config.max_ques_size
# Word / char vocabulary sizes and the number of characters kept per word.
VW, VC, W = config.word_vocab_size, config.char_vocab_size, config.max_word_size
# LSTM hidden size, char-embedding width, word-embedding width, char-CNN output width.
d, dc, dw, dco = config.hidden_size, config.char_emb_size, config.word_emb_size, config.char_out_size
# POS vocabulary size and POS embedding width, taken from the loaded POS tables.
VP, dp = len(pos2int), len(pos_embd[0])


def run_bidirectional_dynamic_rnn(cell_fw, cell_bw, inputs, sequence_length, scope):
    """Run tf.nn.bidirectional_dynamic_rnn on `inputs` with extra leading dimensions.

    `inputs` is passed through utils.helper.flatten(inputs, 2) before the RNN, and
    each directional output is mapped back with reconstruct(..., inputs, 2).
    (Presumably this collapses all but the last two axes into the batch axis; confirm
    against utils.helper.)

    Returns:
        (forward_outputs, backward_outputs)
    """
    directional_outputs, _final_states = tf.nn.bidirectional_dynamic_rnn(
        cell_fw, cell_bw, flatten(inputs, 2),
        sequence_length=sequence_length, dtype='float', scope=scope)
    return tuple(reconstruct(output, inputs, 2) for output in directional_outputs)

# Build on a fresh default graph. Without this, re-running the cell fails on the
# already-existing 'global_step' variable and would add a second copy of each loss
# to the 'losses' collection.
tf.reset_default_graph()

with tf.device('/device:GPU:3'):

    # Placeholders. Shapes match the arrays built by get_feed:
    # context ids / masks / answer flags [N, M, JX], context chars [N, M, JX, W],
    # question ids / masks [N, JQ], question chars [N, JQ, W].
    x = tf.placeholder('int32', [N, None, None], name='x')
    cx = tf.placeholder('int32', [N, None, None, W], name='cx')
    x_mask = tf.placeholder('bool', [N, None, None], name='x_mask')
    # POS-tag ids (rows of pos_emb_mat), aligned token-by-token with x and q.
    pos_x = tf.placeholder('int32', [N, None, None], name='pos_x')
    q = tf.placeholder('int32', [N, None], name='q')
    cq = tf.placeholder('int32', [N, None, W], name='cq')
    q_mask = tf.placeholder('bool', [N, None], name='q_mask')
    pos_q = tf.placeholder('int32', [N, None], name='pos_q')
    y1 = tf.placeholder('bool', [N, None, None], name='y1')
    y2 = tf.placeholder('bool', [N, None, None], name='y2')
    is_train = tf.placeholder('bool', [], name='is_train')
    new_emb_mat = tf.placeholder('float', [None, config.word_emb_size], name='new_emb_mat')
    # Dropout keep-probability: config.input_keep_prob while training, 1.0 otherwise.
    input_keep_prob = tf.cond(is_train, lambda: config.input_keep_prob, lambda: tf.constant(1.0))

    global_step = tf.get_variable('global_step', shape=[], dtype='int32', initializer=tf.constant_initializer(0), trainable=False)
    tensor_dict = {}

    with tf.variable_scope("embedding_layer"):
        if config.use_char_emb:
            with tf.variable_scope("char"):

                char_emb_mat = tf.get_variable("char_emb_mat", shape=[VC, dc], dtype='float')

                Acx = tf.nn.embedding_lookup(char_emb_mat, cx)  # [N, M, JX, W, dc]
                Acq = tf.nn.embedding_lookup(char_emb_mat, cq)  # [N, JQ, W, dc]
                Acx = tf.reshape(Acx, [-1, JX, W, dc])
                Acq = tf.reshape(Acq, [-1, JQ, W, dc])

                filter_sizes = list(map(int, config.out_channel_dims.split(',')))
                heights = list(map(int, config.filter_heights.split(',')))

                with tf.variable_scope("conv"):
                    # Context and question share the char-CNN weights.
                    xx = multi_conv1d(Acx, filter_sizes, heights, "VALID", is_train, config.keep_prob, scope="xx")
                    tf.get_variable_scope().reuse_variables()
                    qq = multi_conv1d(Acq, filter_sizes, heights, "VALID", is_train, config.keep_prob, scope="xx")

                    xx = tf.reshape(xx, [-1, M, JX, dco])
                    qq = tf.reshape(qq, [-1, JQ, dco])

        if config.use_pos_emb:
            with tf.name_scope("pos"):

                if config.mode == 'train':
                    pos_emb_mat = tf.get_variable("pos_emb_mat", dtype='float', shape=[VP, dp], initializer=get_initializer(config.pos_emb_mat))
                else:
                    pos_emb_mat = tf.get_variable("pos_emb_mat", shape=[VP, dp], dtype='float')

                # Index with POS ids. x / q are word ids into a vocabulary of size VW,
                # not rows of this [VP, dp] table.
                Px = tf.nn.embedding_lookup(pos_emb_mat, pos_x)  # [N, M, JX, dp]
                Pq = tf.nn.embedding_lookup(pos_emb_mat, pos_q)  # [N, JQ, dp]

        if config.use_word_emb:
            with tf.name_scope("word"):

                if config.mode == 'train':
                    word_emb_mat = tf.get_variable("word_emb_mat", dtype='float', shape=[VW, dw], initializer=get_initializer(config.emb_mat))
                else:
                    word_emb_mat = tf.get_variable("word_emb_mat", shape=[VW, dw], dtype='float')

                # Extra per-batch rows (fed via new_emb_mat) are appended after the VW trained rows.
                word_emb_mat = tf.concat([word_emb_mat, new_emb_mat], 0)

                Ax = tf.nn.embedding_lookup(word_emb_mat, x)  # [N, M, JX, dw]
                Aq = tf.nn.embedding_lookup(word_emb_mat, q)  # [N, JQ, dw]

                tensor_dict['x'] = Ax
                tensor_dict['q'] = Aq

        # Concatenate every enabled embedding (char, word, POS, in that order) on the feature axis.
        x_features, q_features = [], []
        if config.use_char_emb:
            x_features.append(xx)
            q_features.append(qq)
        if config.use_word_emb:
            x_features.append(Ax)
            q_features.append(Aq)
        if config.use_pos_emb:
            x_features.append(Px)
            q_features.append(Pq)
        assert x_features, "enable at least one of use_char_emb / use_word_emb / use_pos_emb"
        xx = tf.concat(x_features, 3)  # [N, M, JX, di]
        qq = tf.concat(q_features, 2)  # [N, JQ, di]

    with tf.variable_scope("highway_network_layer"):
        # Context and question share the highway-network weights.
        xx = highway_network(xx, config.highway_num_layers, is_train=is_train)
        tf.get_variable_scope().reuse_variables()
        qq = highway_network(qq, config.highway_num_layers, is_train=is_train)

        tensor_dict['xx'] = xx
        tensor_dict['qq'] = qq

    x_len = tf.reduce_sum(tf.cast(x_mask, 'int32'), 2)  # [N, M]
    q_len = tf.reduce_sum(tf.cast(q_mask, 'int32'), 1)  # [N]

    flat_len_q = None if q_len is None else tf.cast(flatten(q_len, 0), 'int64')
    flat_len_x = None if x_len is None else tf.cast(flatten(x_len, 0), 'int64')

    with tf.variable_scope("contextual_layer"):
        cell_fw = tf.contrib.rnn.BasicLSTMCell(d, state_is_tuple=True)
        cell_bw = tf.contrib.rnn.BasicLSTMCell(d, state_is_tuple=True)

        d_cell_fw = tf.contrib.rnn.DropoutWrapper(cell_fw, input_keep_prob=input_keep_prob)
        d_cell_bw = tf.contrib.rnn.DropoutWrapper(cell_bw, input_keep_prob=input_keep_prob)

        fw_u, bw_u = run_bidirectional_dynamic_rnn(d_cell_fw, d_cell_bw, qq, flat_len_q, 'lstm')

        u = tf.concat([fw_u, bw_u], 2)

        # The context encoder reuses the question encoder's LSTM weights (same 'lstm' scope).
        tf.get_variable_scope().reuse_variables()

        fw_h, bw_h = run_bidirectional_dynamic_rnn(cell_fw, cell_bw, xx, flat_len_x, 'lstm')

        h = tf.concat([fw_h, bw_h], 3)

        tensor_dict['u'] = u
        tensor_dict['h'] = h

    with tf.variable_scope("attention_layer"):
        h_aug = tf.tile(tf.expand_dims(h, 3), [1, 1, 1, JQ, 1])
        u_aug = tf.tile(tf.expand_dims(tf.expand_dims(u, 1), 1), [1, M, JX, 1, 1])
        h_mask_aug = tf.tile(tf.expand_dims(x_mask, 3), [1, 1, 1, JQ])
        u_mask_aug = tf.tile(tf.expand_dims(tf.expand_dims(q_mask, 1), 1), [1, M, JX, 1])
        hu_mask = h_mask_aug & u_mask_aug

        h_u = h_aug * u_aug

        with tf.variable_scope("similarity"):
            sim = linear([tf.concat([h_aug, u_aug, h_u], -1)], 1, is_train=is_train, scope="sim")
            sim = tf.squeeze(sim, [len(sim.get_shape().as_list())-1])
            sim = exp_mask(sim, hu_mask)

            # Attention distributions kept for inspection.
            a_u = tf.nn.softmax(sim)
            a_h = tf.nn.softmax(tf.reduce_max(sim, 3))
            tensor_dict['a_u'] = a_u
            tensor_dict['a_h'] = a_h

        with tf.variable_scope("context_2_query"):
            a = softmax(sim)
            rank_u = len(u_aug.get_shape().as_list())
            u_a = tf.reduce_sum(tf.expand_dims(a, -1) * u_aug, rank_u-2)

        with tf.variable_scope("query_2_context"):
            b = softmax(tf.reduce_max(sim, 3))
            rank_h = len(h.get_shape().as_list())
            h_a = tf.reduce_sum(tf.expand_dims(b, -1) * h, rank_h-2)
            h_a = tf.tile(tf.expand_dims(h_a, 2), [1, 1, JX, 1])

        with tf.variable_scope("final"):
            g = tf.concat([h, u_a, h * u_a, h * h_a], 3)

    with tf.variable_scope("modeling_layer"):
        cellm1_fw = tf.contrib.rnn.BasicLSTMCell(d, state_is_tuple=True)
        cellm1_bw = tf.contrib.rnn.BasicLSTMCell(d, state_is_tuple=True)
        d_cellm1_fw = tf.contrib.rnn.DropoutWrapper(cellm1_fw, input_keep_prob=input_keep_prob)
        d_cellm1_bw = tf.contrib.rnn.DropoutWrapper(cellm1_bw, input_keep_prob=input_keep_prob)

        fw_g0, bw_g0 = run_bidirectional_dynamic_rnn(d_cellm1_fw, d_cellm1_bw, g, flat_len_x, 'g0')

        g0 = tf.concat([fw_g0, bw_g0], 3)

        cellm2_fw = tf.contrib.rnn.BasicLSTMCell(d, state_is_tuple=True)
        cellm2_bw = tf.contrib.rnn.BasicLSTMCell(d, state_is_tuple=True)
        d_cellm2_fw = tf.contrib.rnn.DropoutWrapper(cellm2_fw, input_keep_prob=input_keep_prob)
        d_cellm2_bw = tf.contrib.rnn.DropoutWrapper(cellm2_bw, input_keep_prob=input_keep_prob)

        fw_g1, bw_g1 = run_bidirectional_dynamic_rnn(d_cellm2_fw, d_cellm2_bw, g0, flat_len_x, 'g1')

        g1 = tf.concat([fw_g1, bw_g1], 3)

    with tf.variable_scope("output_layer"):
        # Start-position logits.
        logits1 = linear([tf.concat([g1, g], -1)], 1, input_keep_prob=config.input_keep_prob, is_train=is_train, scope="logits1")
        logits1 = tf.squeeze(logits1, [len(logits1.get_shape().as_list())-1])
        logits1 = exp_mask(logits1, x_mask)

        # Attention-weighted summary of g1 under the start distribution, broadcast back to every position.
        a = softmax(tf.reshape(logits1, [N, M * JX]))
        g1_reshaped = tf.reshape(g1, [N, M * JX, 2 * d])
        rank_g1 = len(g1_reshaped.get_shape().as_list())
        a1i = tf.reduce_sum(tf.expand_dims(a, -1) * g1_reshaped, rank_g1-2)
        a1i = tf.tile(tf.expand_dims(tf.expand_dims(a1i, 1), 1), [1, M, JX, 1])

        g2_input = tf.concat([g, g1, a1i, g1 * a1i], 3)

        cello_fw = tf.contrib.rnn.BasicLSTMCell(d, state_is_tuple=True)
        cello_bw = tf.contrib.rnn.BasicLSTMCell(d, state_is_tuple=True)
        d_cello_fw = tf.contrib.rnn.DropoutWrapper(cello_fw, input_keep_prob=input_keep_prob)
        d_cello_bw = tf.contrib.rnn.DropoutWrapper(cello_bw, input_keep_prob=input_keep_prob)

        fw_g2, bw_g2 = run_bidirectional_dynamic_rnn(d_cello_fw, d_cello_bw, g2_input, flat_len_x, 'g2')

        g2 = tf.concat([fw_g2, bw_g2], 3)

        # End-position logits.
        logits2 = linear([tf.concat([g2, g], -1)], 1, input_keep_prob=config.input_keep_prob, is_train=is_train, scope="logits2")
        logits2 = tf.squeeze(logits2, [len(logits2.get_shape().as_list())-1])
        logits2 = exp_mask(logits2, x_mask)

        logits1 = tf.reshape(logits1, [-1, M * JX])
        flat_yp1 = tf.nn.softmax(logits1)
        yp1 = tf.reshape(flat_yp1, [-1, M, JX])
        logits2 = tf.reshape(logits2, [-1, M * JX])
        flat_yp2 = tf.nn.softmax(logits2)
        yp2 = tf.reshape(flat_yp2, [-1, M, JX])

        tensor_dict['g1'] = g1
        tensor_dict['g2'] = g2

    # Loss: cross-entropy of the start (y1) and end (y2) distributions over the flattened
    # M * JX context positions. Batch rows whose q_mask is all False (padding rows that
    # get_feed leaves empty) get weight 0 in both terms.
    loss_mask = tf.reduce_max(tf.cast(q_mask, 'float'), 1)
    start_losses = tf.nn.softmax_cross_entropy_with_logits(logits=logits1, labels=tf.cast(tf.reshape(y1, [-1, M * JX]), 'float'))
    end_losses = tf.nn.softmax_cross_entropy_with_logits(logits=logits2, labels=tf.cast(tf.reshape(y2, [-1, M * JX]), 'float'))
    ce_loss1 = tf.reduce_mean(loss_mask * start_losses)
    ce_loss2 = tf.reduce_mean(loss_mask * end_losses)
    tf.add_to_collection('losses', ce_loss1)
    tf.add_to_collection("losses", ce_loss2)

    loss = tf.add_n(tf.get_collection('losses'), name='loss')
    tf.summary.scalar(loss.op.name, loss)
    tf.add_to_collection('ema/scalar', loss)

    variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=tf.get_variable_scope().name)
    for var in variables:
        tensor_dict[var.name] = var

    # Moving average of every trainable variable.
    var_ema = tf.train.ExponentialMovingAverage(config.var_decay)
    ema_ops = [var_ema.apply(tf.trainable_variables())]

    if config.mode == 'train':
        # Smoothed copies of the 'ema/scalar' / 'ema/vector' tensors for TensorBoard.
        ema = tf.train.ExponentialMovingAverage(config.decay)
        ema_ops.append(ema.apply(tf.get_collection("ema/scalar")))

        for var in tf.get_collection("ema/scalar"):
            ema_var = ema.average(var)
            tf.summary.scalar(ema_var.op.name, ema_var)
        for var in tf.get_collection("ema/vector"):
            ema_var = ema.average(var)
            tf.summary.histogram(ema_var.op.name, ema_var)

    # Evaluating `loss` runs every EMA update (variable averages and, in train mode,
    # the loss average) rather than only the last one created.
    with tf.control_dependencies(ema_ops):
        loss = tf.identity(loss)

    summary = tf.summary.merge_all()

    optimizer = tf.train.AdamOptimizer(config.init_lr)
    grads = optimizer.compute_gradients(loss)
    train_op = optimizer.apply_gradients(grads, global_step=global_step)

In [56]:
def get_feed(batch):
    """Convert a DataSet batch into the dense numpy arrays fed to the graph placeholders.

    Every array is padded to the fixed sizes N, M, JX, JQ, W. Padding is 0 / False,
    except the POS arrays, which default to pos2int['OTHER']. Tokens past
    config.max_num_sents / max_sent_size / max_ques_size / max_word_size are dropped.

    Returns:
        (x, pos_x, pos_q, cx, x_mask, q, cq, q_mask, y1, y2, new_emb_mat)
    """
    assert isinstance(batch, DataSet)
    other_pos = pos2int['OTHER']
    temp_x = np.zeros([N, M, JX], dtype='int32')
    temp_pos_x = np.full([N, M, JX], other_pos, dtype='int32')
    temp_pos_q = np.full([N, JQ], other_pos, dtype='int32')
    temp_cx = np.zeros([N, M, JX, W], dtype='int32')
    temp_x_mask = np.zeros([N, M, JX], dtype='bool')
    temp_q = np.zeros([N, JQ], dtype='int32')
    temp_cq = np.zeros([N, JQ, W], dtype='int32')
    temp_q_mask = np.zeros([N, JQ], dtype='bool')
    temp_y1 = np.zeros([N, M, JX], dtype='bool')
    temp_y2 = np.zeros([N, M, JX], dtype='bool')
    temp_new_emb_mat = batch.shared['new_emb_mat']

    X = batch.data['x']
    CX = batch.data['cx']

    for i in range(len(batch.data['q'])):
        # Context POS ids. shared['pos_x'][x1][x2] is used as one tag sequence, so it
        # only fills sentence slot 0.
        # NOTE(review): unlike the word ids below, this ignores config.squash / config.single.
        x1, x2 = batch.data['*x'][i]
        for k, tag in enumerate(batch.shared['pos_x'][x1][x2]):
            if k == config.max_sent_size:
                break
            temp_pos_x[i, 0, k] = pos2int[tag]

        # Question POS ids, looked up via batch.data['idxs'] (presumably the example's
        # index in the full dataset; confirm against DataSet).
        idx = batch.data['idxs'][i]
        for j, tag in enumerate(batch.shared['pos_q'][idx]):
            if j == config.max_ques_size:
                break
            temp_pos_q[i, j] = pos2int[tag]

    # Answer span: one of the gold spans is chosen at random. y1 marks the start token,
    # y2 the last token (stop index is exclusive).
    for i, (xi, cxi, yi) in enumerate(zip(X, CX, batch.data['y'])):
        start_idx, stop_idx = random.choice(yi)
        j, k = start_idx
        j2, k2 = stop_idx
        if config.single:
            X[i] = [xi[j]]
            CX[i] = [cxi[j]]
            j, j2 = 0, 0
        if config.squash:
            offset = sum(map(len, xi[:j]))
            j, k = 0, k + offset
            offset = sum(map(len, xi[:j2]))
            j2, k2 = 0, k2 + offset
        temp_y1[i, j, k] = True
        temp_y2[i, j2, k2-1] = True

    def _get_word(word):
        # Try several casings in word2idx, then (optionally) in new_word2idx, whose ids
        # are offset past word2idx; 1 is returned when nothing matches.
        vocab = batch.shared['word2idx']
        for each in (word, word.lower(), word.capitalize(), word.upper()):
            if each in vocab:
                return vocab[each]
        if config.use_glove_for_unk:
            new_vocab = batch.shared['new_word2idx']
            for each in (word, word.lower(), word.capitalize(), word.upper()):
                if each in new_vocab:
                    return new_vocab[each] + len(vocab)
        return 1

    def _get_char(char):
        # Character id, or 1 when the character is not in char2idx.
        return batch.shared['char2idx'].get(char, 1)

    for i, xi in enumerate(X):
        if config.squash:
            xi = [list(itertools.chain(*xi))]
        for j, xij in enumerate(xi):
            if j == config.max_num_sents:
                break
            for k, xijk in enumerate(xij):
                if k == config.max_sent_size:
                    break
                each = _get_word(xijk)
                assert isinstance(each, int), each
                temp_x[i, j, k] = each
                temp_x_mask[i, j, k] = True

    for i, cxi in enumerate(CX):
        if config.squash:
            cxi = [list(itertools.chain(*cxi))]
        for j, cxij in enumerate(cxi):
            if j == config.max_num_sents:
                break
            for k, cxijk in enumerate(cxij):
                if k == config.max_sent_size:
                    break
                for l, cxijkl in enumerate(cxijk):
                    if l == config.max_word_size:
                        break
                    temp_cx[i, j, k, l] = _get_char(cxijkl)

    for i, qi in enumerate(batch.data['q']):
        for j, qij in enumerate(qi):
            if j == config.max_ques_size:
                break
            temp_q[i, j] = _get_word(qij)
            temp_q_mask[i, j] = True

    for i, cqi in enumerate(batch.data['cq']):
        for j, cqij in enumerate(cqi):
            if j == config.max_ques_size:
                break
            for k, cqijk in enumerate(cqij):
                if k == config.max_word_size:
                    break
                temp_cq[i, j, k] = _get_char(cqijk)

    return (temp_x, temp_pos_x, temp_pos_q, temp_cx, temp_x_mask, temp_q, temp_cq,
            temp_q_mask, temp_y1, temp_y2, temp_new_emb_mat)


def get_feed_dict(data_set, is_train_cond):
    """Build the sess.run feed dict for `data_set`; `is_train_cond` is fed to `is_train`."""
    (temp_x, temp_pos_x, temp_pos_q, temp_cx, temp_x_mask, temp_q, temp_cq,
     temp_q_mask, temp_y1, temp_y2, temp_new_emb_mat) = get_feed(data_set)

    return {
        x: temp_x,
        pos_x: temp_pos_x,
        pos_q: temp_pos_q,
        cx: temp_cx,
        x_mask: temp_x_mask,
        q: temp_q,
        cq: temp_cq,
        q_mask: temp_q_mask,
        y1: temp_y1,
        y2: temp_y2,
        is_train: is_train_cond,
        new_emb_mat: temp_new_emb_mat,
    }

def train_step(batch, get_summary=False):
    """Run one optimizer step on `batch`, an (indices, DataSet) pair.

    Returns:
        (loss value, evaluated `summary` or None when get_summary is False,
         value returned for `train_op`)
    """
    _idxs, data_set = batch
    feed_dict = get_feed_dict(data_set, True)

    if not get_summary:
        step_loss, step_op_result = sess.run([loss, train_op], feed_dict=feed_dict)
        return step_loss, None, step_op_result

    step_loss, step_summary, step_op_result = sess.run([loss, summary, train_op], feed_dict=feed_dict)
    return step_loss, step_summary, step_op_result

def eval_step(batch, summary_tag='train'):
    """Evaluate the model on `batch` (an (indices, DataSet) pair) and compute span F1.

    Args:
        batch: (indices, DataSet) pair as produced by the batching cells.
        summary_tag: prefix of the F1 summary tag ('<summary_tag>/f1').

    Returns:
        (loss, yp1, yp2, mean F1, tf.Summary holding the mean F1)
    """
    _idxs, data_set = batch
    feed_dict = get_feed_dict(data_set, False)

    # Fetch only what is used below. Fetching every tensor_dict entry (all variables and
    # large intermediate activations) on each eval step is wasted work.
    temp_yp1, temp_yp2, temp_loss = sess.run([yp1, yp2, loss], feed_dict=feed_dict)

    # Drop the padding rows beyond the real examples in this batch.
    num_examples = data_set.num_examples
    temp_yp1, temp_yp2 = temp_yp1[:num_examples], temp_yp2[:num_examples]

    spans, _scores = zip(*[get_best_span(yp1_i, yp2_i) for yp1_i, yp2_i in zip(temp_yp1, temp_yp2)])

    f1s = [func_span_f1(yi, span) for yi, span in zip(data_set.data['y'], spans)]
    f1 = np.mean(f1s)
    f1_summary = tf.Summary(value=[tf.Summary.Value(tag='{}/f1'.format(summary_tag), simple_value=f1)])
    return temp_loss, temp_yp1, temp_yp2, f1, f1_summary

In [7]:
saver = tf.train.Saver()
save_path = os.path.join(config.save_dir, config.model_name)

# The graph pins its ops to '/device:GPU:3'. Soft placement lets TensorFlow place them
# on another device when that one is not available.
session_config = tf.ConfigProto(allow_soft_placement=True)
sess = tf.Session(config=session_config)
sess.run(tf.global_variables_initializer())

writer = tf.summary.FileWriter(config.log_dir, graph=tf.get_default_graph())
   

In [8]:
# checkpoint = tf.train.latest_checkpoint(config.save_dir)
# saver.restore(sess, checkpoint)

In [13]:
# Training length and per-step batch size, both taken from the config.
num_steps, batch_size = config.num_steps, config.batch_size

In [None]:
batches = train_data.get_batches(batch_size, num_batches=num_steps, shuffle=True)

# Each element pairs grouper(..., num_groups=1) groups with data_set.divide(1) parts;
# presumably a 1-tuple holding one (indices, DataSet) pair, hence batch[0] below.
multi_batches = (tuple(zip(grouper(idxs, batch_size, shorten=True, num_groups=1),
                           data_set.divide(1))) for idxs, data_set in batches)

temp_global_step = None
last_saved_step = None
for batch in tqdm(multi_batches, total=num_steps):
    # train_op increments global_step, so the step about to run is the current value + 1.
    temp_global_step = sess.run(global_step) + 1

    get_summary = temp_global_step % config.log_period == 0
    temp_loss, temp_summary, temp_train_op = train_step(batch[0], get_summary=get_summary)

    if get_summary:
        print("Steps:{}".format(temp_global_step), ", Loss: {}".format(temp_loss))
        writer.add_summary(temp_summary, temp_global_step)

    # Periodic checkpoint.
    if temp_global_step % config.save_period == 0:
        saver.save(sess, save_path=save_path)
        last_saved_step = temp_global_step

# Save the final weights too, unless the last step already triggered a checkpoint;
# otherwise up to save_period - 1 trailing steps of training would be lost.
if temp_global_step is not None and temp_global_step != last_saved_step:
    saver.save(sess, save_path=save_path)
writer.flush()

  0%|          | 0/20000 [00:00<?, ?it/s]

In [59]:
batch_size = config.batch_size

# Evaluate on one shuffled training batch. Only a single batch is generated, instead of
# producing num_steps shuffled batches and discarding all but the last.
eval_batches = train_data.get_batches(batch_size, num_batches=1, shuffle=True)
eval_idxs, eval_data_set = next(iter(eval_batches))
batch_1 = tuple(zip(grouper(eval_idxs, batch_size, shorten=True, num_groups=1),
                    eval_data_set.divide(1)))

temp_loss, temp_yp1, temp_yp2, f1, f1_summary = eval_step(batch_1[0])