In [1]:
import os

# Expose only GPU index 1 to this process; done before torch / TensorFlow are imported below.
os.environ.update(CUDA_VISIBLE_DEVICES='1')

In [2]:
import bert
import bert.modeling, bert.tokenization
import transformers

In [3]:
import os
import sys

# Root of the self-attentive-parser checkout. Override it with the SELF_ATTENTIVE_PARSER_DIR
# environment variable instead of editing this user-specific default path.
PARSER_ROOT = os.environ.get("SELF_ATTENTIVE_PARSER_DIR", "/home/husein/parsing/self-attentive-parser")
# `src/` is searched first so the parser modules imported below resolve from there;
# the checkout root is appended as a fallback.
sys.path.insert(0, os.path.join(PARSER_ROOT, "src"))
sys.path.append(PARSER_ROOT)

In [4]:
import argparse
import itertools
import os.path
import time
import shutil
import re
import json

import torch
import torch.optim.lr_scheduler

import numpy as np

import evaluate
import trees_newline as trees
import vocabulary
import nkutil
import parse_nk_tiny_bert as parse_nk
# NOTE(review): this module-level alias appears unused in this notebook (bertify_batch binds
# its own local `tokens` list); kept as-is in case other code relies on it.
tokens = parse_nk

In [5]:
# Fine-tuned parser checkpoint; the next cell reads its 'spec' and 'state_dict' entries.
MODEL_PATH = 'models/en_bert_dev=76.79.pt'
if parse_nk.use_cuda:
    info = torch.load(MODEL_PATH)
else:
    # Always define `info` (the next cell needs it). On a CPU-only machine, remap any
    # CUDA-saved tensors onto the CPU so loading does not fail.
    info = torch.load(MODEL_PATH, map_location=lambda storage, location: storage)

In [6]:
# Rebuild the PyTorch chart parser from the checkpoint's spec and weights.
parser = parse_nk.NKChartParser.from_spec(
    info['spec'],
    info['state_dict'],
)
# The `bert_model` hyper-parameter recorded in the checkpoint spec.
bert_model = info['spec']['hparams']['bert_model']

In [7]:
# Gold test trees: their leaves are re-parsed in the evaluation loop below, and the predictions
# are scored against these trees with EVALB.
test_treebank = trees.load_trees('test.txt')

In [8]:
import tensorflow as tf

# An InteractiveSession left over from a previous run of this cell keeps its graph installed as
# the default; TF documents reset_default_graph() as unsafe while a session is active, so close it.
_previous_sess = globals().get('sess')
if _previous_sess is not None:
    _previous_sess.close()
# Start from an empty default graph so re-running the notebook does not accumulate ops.
tf.compat.v1.reset_default_graph()
sess = tf.compat.v1.InteractiveSession()
# PyTorch parameters of the parser, keyed by module path; the graph builders copy them into TF.
sd = parser.state_dict()

# Vocabularies as index-ordered lists: position i holds the symbol whose id is i.
LABEL_VOCAB = [label for label, _ in sorted(parser.label_vocab.indices.items(), key=lambda kv: kv[1])]
TAG_VOCAB = [tag for tag, _ in sorted(parser.tag_vocab.indices.items(), key=lambda kv: kv[1])]

In [9]:
def make_bert(input_ids, word_end_mask):
    """Encode subwords with BERT and return per-word projected features plus padding bookkeeping.

    Args:
        input_ids: int32 subword ids, batch x num_subwords (see the placeholders in make_network).
        word_end_mask: int32, same shape; 1 on the last subword of every word (bertify_batch
            also marks [CLS] and [SEP]), 0 elsewhere including right padding.

    Returns:
        projected_annotations: BERT output at each word end, packed across the batch (one row
            per word, batch-major) and multiplied by the transposed `project_bert.weight`.
        nonpad_ids: int32 positions of those words in a flattened batch x max_words layout.
        dim_flat: 1-element tensor [batch * max_words].
        dim_padded: 2-element tensor [batch, max_words].
        valid_mask: bool batch x max_words, True for real words.
        sentence_lengths: number of words per sentence.
    """
    # A subword counts as real input iff some word end occurs at or after it. Deriving the mask
    # from word_end_mask (rather than input_ids) zeroes exactly the right padding.
    no_end_here = 1 - word_end_mask
    input_mask = 1 - tf.compat.v1.cumprod(no_end_here, axis=-1, reverse=True)
    token_type_ids = tf.compat.v1.zeros_like(input_ids)
    bert_model = make_bert_instance(input_ids, input_mask, token_type_ids)

    subword_features = bert_model.get_sequence_output()
    hidden_size = int(subword_features.shape[-1])
    features_flat = tf.compat.v1.reshape(subword_features, [-1, hidden_size])
    word_end_positions = tf.compat.v1.to_int32(
        tf.compat.v1.where(tf.compat.v1.reshape(word_end_mask, (-1,))))[:, 0]
    bert_features_packed = tf.compat.v1.gather(features_flat, word_end_positions)
    projection = tf.compat.v1.constant(sd['project_bert.weight'].cpu().numpy().transpose())
    projected_annotations = tf.compat.v1.matmul(bert_features_packed, projection)

    # input_mask is over subwords, whereas valid_mask is over words
    sentence_lengths = tf.compat.v1.reduce_sum(word_end_mask, -1)
    max_words = tf.compat.v1.reduce_max(sentence_lengths)
    valid_mask = tf.compat.v1.range(max_words)[None, :] < sentence_lengths[:, None]
    dim_padded = tf.compat.v1.shape(valid_mask)[:2]
    mask_flat = tf.compat.v1.reshape(valid_mask, (-1,))
    dim_flat = tf.compat.v1.shape(mask_flat)[:1]
    nonpad_ids = tf.compat.v1.to_int32(tf.compat.v1.where(mask_flat)[:, 0])

    return projected_annotations, nonpad_ids, dim_flat, dim_padded, valid_mask, sentence_lengths

In [10]:
# Position table from the PyTorch checkpoint as a named graph constant;
# make_network slices its first max_words rows.
position_table = tf.compat.v1.constant(
    sd['embedding.position_table'].cpu().numpy(),
    name="position_table",
)

In [11]:
def make_layer_norm(input, torch_name, name):
    """Normalise `input` over axis 1 with the gain/bias of a PyTorch layer-norm module.

    Args:
        input: packed 2-D tensor (rows = words in this notebook); normalised along axis 1.
        torch_name: prefix in `sd` whose `.a_2` (scale) and `.b_2` (offset) are used.
        name: prefix for the named offset/scale constants.
    """
    # TODO(nikita): The epsilon here isn't quite the same as in pytorch
    # The pytorch code adds eps=1e-3 to the standard deviation, while this
    # tensorflow code adds eps=1e-6 to the variance.
    # However, the resulting mismatch in floating-point values does not seem to
    # translate to any noticable changes in the parser's tree output
    mean, variance = tf.compat.v1.nn.moments(input, [1], keep_dims=True)
    offset = tf.compat.v1.constant(sd[f'{torch_name}.b_2'].cpu().numpy(), name=f"{name}/offset")
    scale = tf.compat.v1.constant(sd[f'{torch_name}.a_2'].cpu().numpy(), name=f"{name}/scale")
    return tf.compat.v1.nn.batch_normalization(
        input, mean, variance, offset=offset, scale=scale, variance_epsilon=1e-6)


def make_heads(input, shape_bthf, shape_xtf, torch_name, name):
    """Project `input` with the per-head weights `sd[torch_name]` and lay the result out per head.

    `sd[torch_name]` is indexed (num_heads, in_features, head_features). Transposing to
    (in_features, num_heads, head_features) and flattening the last two axes gives one
    (in_features, num_heads * head_features) matrix, so all heads share one matmul.

    Args:
        input: padded-flat features, (batch * time, in_features).
        shape_bthf: target shape [batch, time, num_heads, -1].
        shape_xtf: target shape [batch * num_heads, time, -1].
        torch_name: key of the per-head weight tensor in `sd`.
        name: prefix for the named weight constant.

    Returns:
        Tensor of shape (batch * num_heads, time, head_features).
    """
    weights = sd[torch_name].cpu().numpy().transpose((1, 0, 2))
    # Use the checkpoint's own input width rather than a hard-coded 512.
    weights = weights.reshape((weights.shape[0], -1))
    res = tf.compat.v1.matmul(input, tf.compat.v1.constant(weights, name=f"{name}/W"))
    res = tf.compat.v1.reshape(res, shape_bthf)      # batch x time x num_heads x feat
    res = tf.compat.v1.transpose(res, (0, 2, 1, 3))  # batch x num_heads x time x feat
    res = tf.compat.v1.reshape(res, shape_xtf)       # _ x time x feat
    return res

def make_attention(input, nonpad_ids, dim_flat, dim_padded, valid_mask, torch_name, name, d_model=1024):
    """Partitioned multi-head self-attention sublayer with residual connection and layer norm.

    The packed `input` (one row per real word across the batch) is scattered into a padded
    (batch * max_words, features) layout, attended per sentence, and gathered back. The first
    half of the feature axis is the data partition and the second the position partition.
    Each partition has its own Q/K/V weights (`w_qs1/w_ks1/w_vs1` and `w_qs2/w_ks2/w_vs2`).
    Each also has its own output projection (`proj1` / `proj2`). Q and K of both partitions are
    concatenated, so the logits are the sum of both partitions' dot products.

    Args:
        input: packed word features.
        nonpad_ids, dim_flat, dim_padded, valid_mask: padding bookkeeping from make_bert.
        torch_name: prefix of this layer's parameters in `sd` (e.g. 'encoder.attn_0').
        name: prefix for named TF constants.
        d_model: width used in the 1/sqrt(d_model) logit scale; the default 1024 is the value
            this exporter has always used for this checkpoint.

    Returns:
        Packed tensor of the same shape as `input`, after residual add and layer norm.
    """
    # Per-head weights are indexed (num_heads, in_features, head_features), so the head count
    # comes from the checkpoint instead of a hard-coded 8.
    num_heads = int(sd[f'{torch_name}.w_qs1'].shape[0])

    input_flat = tf.compat.v1.scatter_nd(indices=nonpad_ids[:, None], updates=input, shape=tf.compat.v1.concat([dim_flat, tf.compat.v1.shape(input)[1:]], axis=0))
    input_flat_dat, input_flat_pos = tf.compat.v1.split(input_flat, 2, axis=-1)

    shape_bthf = tf.compat.v1.concat([dim_padded, [num_heads, -1]], axis=0)
    shape_bhtf = tf.compat.v1.convert_to_tensor([dim_padded[0], num_heads, dim_padded[1], -1])
    shape_xtf = tf.compat.v1.convert_to_tensor([dim_padded[0] * num_heads, dim_padded[1], -1])
    shape_xf = tf.compat.v1.concat([dim_flat, [-1]], axis=0)

    qs1 = make_heads(input_flat_dat, shape_bthf, shape_xtf, f'{torch_name}.w_qs1', f'{name}/q_dat')
    ks1 = make_heads(input_flat_dat, shape_bthf, shape_xtf, f'{torch_name}.w_ks1', f'{name}/k_dat')
    vs1 = make_heads(input_flat_dat, shape_bthf, shape_xtf, f'{torch_name}.w_vs1', f'{name}/v_dat')
    qs2 = make_heads(input_flat_pos, shape_bthf, shape_xtf, f'{torch_name}.w_qs2', f'{name}/q_pos')
    ks2 = make_heads(input_flat_pos, shape_bthf, shape_xtf, f'{torch_name}.w_ks2', f'{name}/k_pos')
    vs2 = make_heads(input_flat_pos, shape_bthf, shape_xtf, f'{torch_name}.w_vs2', f'{name}/v_pos')

    qs = tf.compat.v1.concat([qs1, qs2], axis=-1)
    ks = tf.compat.v1.concat([ks1, ks2], axis=-1)
    attn_logits = tf.compat.v1.matmul(qs, ks, transpose_b=True) / (d_model ** 0.5)

    # Repeat each sentence's word mask for every head and query row, masking padded keys.
    attn_mask = tf.compat.v1.reshape(tf.compat.v1.tile(valid_mask, [1, num_heads * dim_padded[1]]), tf.compat.v1.shape(attn_logits))
    # TODO(nikita): use tf.compat.v1.where and -float('inf') here?
    attn_logits -= 1e10 * tf.compat.v1.to_float(~attn_mask)

    attn = tf.compat.v1.nn.softmax(attn_logits)

    # Undo the head split: back to (batch, time, heads, feat), flatten, then keep real words.
    attended_dat_raw = tf.compat.v1.matmul(attn, vs1)
    attended_dat_flat = tf.compat.v1.reshape(tf.compat.v1.transpose(tf.compat.v1.reshape(attended_dat_raw, shape_bhtf), (0, 2, 1, 3)), shape_xf)
    attended_dat = tf.compat.v1.gather(attended_dat_flat, nonpad_ids)
    attended_pos_raw = tf.compat.v1.matmul(attn, vs2)
    attended_pos_flat = tf.compat.v1.reshape(tf.compat.v1.transpose(tf.compat.v1.reshape(attended_pos_raw, shape_bhtf), (0, 2, 1, 3)), shape_xf)
    attended_pos = tf.compat.v1.gather(attended_pos_flat, nonpad_ids)

    out_dat = tf.compat.v1.matmul(attended_dat, tf.compat.v1.constant(sd[f'{torch_name}.proj1.weight'].cpu().numpy().transpose()))
    out_pos = tf.compat.v1.matmul(attended_pos, tf.compat.v1.constant(sd[f'{torch_name}.proj2.weight'].cpu().numpy().transpose()))

    out = tf.compat.v1.concat([out_dat, out_pos], -1)
    return make_layer_norm(input + out, f'{torch_name}.layer_norm', f'{name}/layer_norm')

def make_dense_relu_dense(input, torch_name, torch_type, name):
    """Dense -> ReLU -> dense using one partition's feed-forward weights.

    `torch_type` is the parameter suffix ('c' for the data half, 'p' for the position half, as
    called from make_ff), selecting `{torch_name}.w_1{torch_type}` and `{torch_name}.w_2{torch_type}`.
    `name` is currently unused (TODO).
    """
    def weight(layer):
        return tf.compat.v1.constant(sd[f'{torch_name}.{layer}{torch_type}.weight'].cpu().numpy().transpose())

    def bias(layer):
        return tf.compat.v1.constant(sd[f'{torch_name}.{layer}{torch_type}.bias'].cpu().numpy())

    hidden = tf.compat.v1.nn.relu(
        tf.compat.v1.nn.bias_add(tf.compat.v1.matmul(input, weight('w_1')), bias('w_1')))
    return tf.compat.v1.nn.bias_add(tf.compat.v1.matmul(hidden, weight('w_2')), bias('w_2'))

def make_ff(input, torch_name, name):
    """Partitioned feed-forward sublayer with residual connection and layer norm.

    The feature axis is split in two halves ('c' and 'p' weights). Each half goes through its
    own dense-ReLU-dense; the halves are re-joined, added to `input`, and layer-normalised.
    """
    halves = tf.compat.v1.split(input, 2, axis=-1)
    transformed = [
        make_dense_relu_dense(halves[0], torch_name, 'c', name="TODO_dat"),
        make_dense_relu_dense(halves[1], torch_name, 'p', name="TODO_pos"),
    ]
    residual = input + tf.compat.v1.concat(transformed, -1)
    return make_layer_norm(residual, f'{torch_name}.layer_norm', f'{name}/layer_norm')

def make_stacks(input, nonpad_ids, dim_flat, dim_padded, valid_mask, num_stacks):
    """Apply `num_stacks` encoder layers, each self-attention followed by feed-forward.

    Layer i reads parameters `encoder.attn_{i}` and `encoder.ff_{i}` from `sd`.
    """
    hidden = input
    for layer_idx in range(num_stacks):
        attended = make_attention(hidden, nonpad_ids, dim_flat, dim_padded, valid_mask,
                                  f'encoder.attn_{layer_idx}', name=f'attn_{layer_idx}')
        hidden = make_ff(attended, f'encoder.ff_{layer_idx}', name=f'ff_{layer_idx}')
    return hidden

def make_layer_norm_with_constants(input, constants):
    """Same normalisation as make_layer_norm, but with prebuilt tensors.

    Args:
        input: 2-D tensor, normalised along axis 1.
        constants: indexable whose [0] is the offset and [1] is the scale tensor.
    """
    # TODO(nikita): The epsilon here isn't quite the same as in pytorch
    # The pytorch code adds eps=1e-3 to the standard deviation, while this
    # tensorflow code adds eps=1e-6 to the variance.
    # However, the resulting mismatch in floating-point values does not seem to
    # translate to any noticable changes in the parser's tree output
    offset, scale = constants[0], constants[1]
    mean, variance = tf.compat.v1.nn.moments(input, [1], keep_dims=True)
    return tf.compat.v1.nn.batch_normalization(input, mean, variance, offset, scale, 1e-6)

def make_flabel_with_constants(input, constants):
    """Span-label scorer: dense -> layer norm -> ReLU -> dense, from prebuilt tensors.

    `constants` is (W1, b1, ln_offset, ln_scale, W2, b2) as returned by make_flabel_constants.
    The final op is named 'flabel'.
    """
    w1, b1, w2, b2 = constants[0], constants[1], constants[4], constants[5]
    hidden = tf.compat.v1.nn.bias_add(tf.compat.v1.matmul(input, w1), b1)
    hidden = tf.compat.v1.nn.relu(make_layer_norm_with_constants(hidden, constants[2:4]))
    return tf.compat.v1.nn.bias_add(tf.compat.v1.matmul(hidden, w2), b2, name='flabel')

def make_ftag(input):
    """Tag scorer built from the `f_tag.*` parameters: dense -> layer norm -> ReLU -> dense.

    Returns per-word tag logits; the final op is named 'ftag'.
    """
    w1 = tf.compat.v1.constant(sd['f_tag.0.weight'].cpu().numpy().transpose())
    b1 = tf.compat.v1.constant(sd['f_tag.0.bias'].cpu().numpy())
    ln_params = (
        tf.compat.v1.constant(sd['f_tag.1.b_2'].cpu().numpy(), name="tag/layer_norm/offset"),
        tf.compat.v1.constant(sd['f_tag.1.a_2'].cpu().numpy(), name="tag/layer_norm/scale"),
    )
    w2 = tf.compat.v1.constant(sd['f_tag.3.weight'].cpu().numpy().transpose())
    b2 = tf.compat.v1.constant(sd['f_tag.3.bias'].cpu().numpy())

    hidden = tf.compat.v1.nn.bias_add(tf.compat.v1.matmul(input, w1), b1)
    hidden = tf.compat.v1.nn.relu(make_layer_norm_with_constants(hidden, ln_params))
    return tf.compat.v1.nn.bias_add(tf.compat.v1.matmul(hidden, w2), b2, name='ftag')

def make_flabel_constants():
    """Build the `f_label.*` parameters as TF constants for make_flabel_with_constants.

    Returns (W1^T, b1, ln_offset, ln_scale, W2^T, b2). Linear weights are transposed for
    `input @ W` matmuls.
    """
    def as_const(key, transpose=False, **kwargs):
        value = sd[key].cpu().numpy()
        return tf.compat.v1.constant(value.transpose() if transpose else value, **kwargs)

    return (
        as_const('f_label.0.weight', transpose=True),
        as_const('f_label.0.bias'),
        as_const('f_label.1.b_2', name="label/layer_norm/offset"),
        as_const('f_label.1.a_2', name="label/layer_norm/scale"),
        as_const('f_label.3.weight', transpose=True),
        as_const('f_label.3.bias'),
    )

def make_network():
    """Build the complete TF1 inference graph for the parser.

    Creates int32 placeholders `input_ids` and `word_end_mask` (batch x num_subwords), then runs
    BERT and the self-attention encoder. It produces:
      * `tags`: (batch, max_words) argmax tag ids, zero at padding positions.
      * `charts`: (batch, max_words, max_words, 1 + f_label outputs) span label scores, zero-padded
        per sentence. Label column 0 is a prepended all-zero column for the not-a-constituent label.

    Returns:
        (input_ids, word_end_mask, charts, tags) graph tensors.
    """
    # Width of the fencepost features = in_features of f_label's first Linear
    # (weight is (out, in)); the forward/backward split point is half of it.
    d_fencepost = int(sd['f_label.0.weight'].shape[1])
    d_half = d_fencepost // 2

    # batch x num_subwords
    input_ids = tf.compat.v1.placeholder(shape=(None, None), dtype=tf.compat.v1.int32, name='input_ids')
    word_end_mask = tf.compat.v1.placeholder(shape=(None, None), dtype=tf.compat.v1.int32, name='word_end_mask')
    input_dat, nonpad_ids, dim_flat, dim_padded, valid_mask, lengths = make_bert(input_ids, word_end_mask)
    input_pos_flat = tf.compat.v1.tile(position_table[:dim_padded[1]], [dim_padded[0], 1])
    input_pos = tf.compat.v1.gather(input_pos_flat, nonpad_ids)

    input_joint = tf.compat.v1.concat([input_dat, input_pos], -1)
    input_joint = make_layer_norm(input_joint, 'embedding.layer_norm', 'embedding/layer_norm')

    word_out = make_stacks(input_joint, nonpad_ids, dim_flat, dim_padded, valid_mask,
                           num_stacks=parser.spec['hparams']['num_layers'])
    # Even feature columns first, then odd ones; presumably the forward/backward halves that
    # the fencepost construction below splits at d_half.
    word_out = tf.compat.v1.concat([word_out[:, 0::2], word_out[:, 1::2]], -1)

    # part-of-speech predictions
    ftag = make_ftag(word_out)
    tags_packed = tf.compat.v1.argmax(ftag, axis=-1)
    tags = tf.compat.v1.reshape(
        tf.compat.v1.scatter_nd(indices=nonpad_ids[:, None], updates=tags_packed, shape=dim_flat),
        dim_padded
        )
    tags = tf.compat.v1.identity(tags, name="tags")

    fp_out = tf.compat.v1.concat([word_out[:-1, :d_half], word_out[1:, d_half:]], -1)

    fp_start_idxs = tf.compat.v1.cumsum(lengths, exclusive=True)
    fp_end_idxs = tf.compat.v1.cumsum(lengths) - 1  # the number of fenceposts is 1 less than the number of words

    # Have to make these outside tf.compat.v1.map_fn for model compression to work
    constants = make_flabel_constants()

    def to_map(start_and_end):
        """Score every span of one sentence, padded to max_words x max_words."""
        start, end = start_and_end
        fp = fp_out[start:end]
        # span (i, j) feature = fp[j] - fp[i]
        span_features = tf.compat.v1.reshape(fp[None, :, :] - fp[:, None, :], (-1, d_fencepost))
        flabel = make_flabel_with_constants(span_features, constants)
        actual_chart_size = end - start
        flabel = tf.compat.v1.reshape(flabel, [actual_chart_size, actual_chart_size, -1])
        amount_to_pad = dim_padded[1] - actual_chart_size
        # extra padding on the label dimension is for the not-a-constituent label,
        # which always has a score of 0
        flabel = tf.compat.v1.pad(flabel, [[0, amount_to_pad], [0, amount_to_pad], [1, 0]])
        return flabel

    charts = tf.compat.v1.map_fn(to_map, (fp_start_idxs, fp_end_idxs), dtype=tf.compat.v1.float32)
    charts = tf.compat.v1.identity(charts, name="charts")

    return input_ids, word_end_mask, charts, tags


In [12]:
# TF variable-name components whose PyTorch attribute has a different name.
_TF_TO_TORCH_ATTR = {
    'kernel': 'weight',
    'gamma': 'weight',
    'output_weights': 'weight',
    'output_bias': 'bias',
    'beta': 'bias',
}


def _resolve_torch_param(root, tf_path):
    """Walk `root` along a '/'-split TF variable name and return the matching PyTorch tensor.

    Components like 'layer_3' become `.layer[3]`; other components are renamed via
    _TF_TO_TORCH_ATTR. If the leaf ends in '_embeddings', the module's `.weight` is used.
    A 'kernel' leaf is transposed with `.t()` to match the TF variable's layout.
    """
    torch_var = root
    for component in tf_path:
        if re.fullmatch(r'[A-Za-z]+_\d+', component):
            parts = re.split(r'_(\d+)', component)  # 'layer_3' -> ['layer', '3', '']
        else:
            parts = [component]
        torch_var = getattr(torch_var, _TF_TO_TORCH_ATTR.get(parts[0], parts[0]))
        if len(parts) >= 2:
            torch_var = torch_var[int(parts[1])]
    leaf = tf_path[-1]
    if leaf.endswith('_embeddings'):
        torch_var = torch_var.weight
    elif leaf == 'kernel':
        torch_var = torch_var.t()
    return torch_var


def make_bert_instance(input_ids, input_mask, token_type_ids):
    """Build the TF `bert.modeling.BertModel` graph and load the parser's fine-tuned BERT weights.

    The TF config is copied from `parser.bert.config`. Every graph variable whose name contains
    'bert' is first initialised and then overwritten, via `Variable.load` in the default session,
    with the corresponding PyTorch parameter of `parser`.

    Returns:
        The constructed `bert.modeling.BertModel`.

    Raises:
        AssertionError: if a PyTorch parameter's shape differs from its TF variable's shape.
    """
    # Transfer BERT config into tensorflow implementation
    config = bert.modeling.BertConfig.from_dict(parser.bert.config.to_dict())
    model = bert.modeling.BertModel(config=config, is_training=False,
        input_ids=input_ids, input_mask=input_mask, token_type_ids=token_type_ids)

    # Next, transfer learned weights (after fine-tuning)
    bert_variables = [v for v in tf.compat.v1.get_collection('variables') if 'bert' in v.name]
    tf.compat.v1.variables_initializer(bert_variables).run()

    for variable in bert_variables:
        tf_path = variable.name.split(':')[0].split('/')
        array = variable.eval()
        # adam_v and adam_m are optimizer slot variables, not needed to use the pretrained model
        if any(part in ("adam_v", "adam_m") for part in tf_path):
            print("Skipping {}".format("/".join(tf_path)))
            continue
        torch_var = _resolve_torch_param(parser, tf_path)
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if tuple(torch_var.shape) != tuple(array.shape):
            raise AssertionError(
                f"Shape mismatch for {variable.name}: "
                f"torch {tuple(torch_var.shape)} vs tf {tuple(array.shape)}")
        variable.load(torch_var.detach().cpu().numpy())
    return model

In [13]:
the_inp_tokens, the_inp_mask, the_out_chart, the_out_tags = make_network()

Instructions for updating:
Use keras.layers.Dense instead.
Instructions for updating:
Please use `layer.__call__` method instead.
Instructions for updating:
Prefer Variable.assign which has equivalent behavior in 2.X.
Instructions for updating:
Use tf.compat.v1.where in 2.0, which has the same broadcast rule as np.where
Instructions for updating:
Use `tf.compat.v1.cast` instead.
Instructions for updating:
Use `tf.compat.v1.cast` instead.


In [14]:
from parse_nk import BERT_TOKEN_MAPPING
def bertify_batch(sentences):
    """Convert word-tokenised sentences into BERT subword ids and a word-end mask.

    Args:
        sentences: list of sentences, each a list of word strings.

    Returns:
        (all_input_ids, all_word_end_mask): int arrays of shape (len(sentences), max_subwords),
        zero-padded on the right. The mask is 1 at [CLS], at the last subword of each word,
        and at [SEP], so each sentence has len(sentence) + 2 ones.

    Raises:
        ValueError: if a sentence needs more than `parser.bert_max_len` subwords.
    """
    all_input_ids = np.zeros((len(sentences), parser.bert_max_len), dtype=int)
    all_word_end_mask = np.zeros((len(sentences), parser.bert_max_len), dtype=int)

    subword_max_len = 0
    for snum, sentence in enumerate(sentences):
        tokens = ["[CLS]"]
        word_end_mask = [1]

        cleaned_words = []
        for word in sentence:
            word = BERT_TOKEN_MAPPING.get(word, word)
            # Re-attach the 'n': ["do", "n't"] becomes ["don", "'t"].
            if word == "n't" and cleaned_words:
                cleaned_words[-1] = cleaned_words[-1] + "n"
                word = "'t"
            cleaned_words.append(word)

        for word in cleaned_words:
            word_tokens = parser.bert_tokenizer.tokenize(word)
            if not word_tokens:
                # A word the tokenizer drops entirely would otherwise re-mark the previous
                # token as a word end, losing a word and misaligning tags/charts.
                word_tokens = ["[UNK]"]
            word_end_mask.extend([0] * (len(word_tokens) - 1) + [1])
            tokens.extend(word_tokens)
        tokens.append("[SEP]")
        word_end_mask.append(1)

        input_ids = parser.bert_tokenizer.convert_tokens_to_ids(tokens)
        if len(input_ids) > parser.bert_max_len:
            raise ValueError(
                f"Sentence {snum} needs {len(input_ids)} subwords, "
                f"more than bert_max_len={parser.bert_max_len}")

        subword_max_len = max(subword_max_len, len(input_ids))

        all_input_ids[snum, :len(input_ids)] = input_ids
        all_word_end_mask[snum, :len(word_end_mask)] = word_end_mask

    return all_input_ids[:, :subword_max_len], all_word_end_mask[:, :subword_max_len]

In [15]:
from tqdm import tqdm

# Parse the test set in batches with the TF graph and decode trees with the PyTorch parser.
eval_batch_size = 16
test_predicted = []
for start_index in tqdm(range(0, len(test_treebank), eval_batch_size)):
    subbatch_trees = test_treebank[start_index:start_index + eval_batch_size]
    subbatch_sentences = [[leaf.word for leaf in tree.leaves()] for tree in subbatch_trees]
    inp_val_tokens, inp_val_mask = bertify_batch(subbatch_sentences)
    out_val_chart, out_val_tags = sess.run(
        (the_out_chart, the_out_tags),
        {the_inp_tokens: inp_val_tokens, the_inp_mask: inp_val_mask})
    # Named `batch_trees`, not `trees`: `trees` is the imported trees_newline module, and
    # rebinding it would break `trees.load_trees` on a re-run of the earlier cell.
    batch_trees = []
    for snum, sentence in enumerate(subbatch_sentences):
        chart_size = len(sentence) + 1
        tf_chart = out_val_chart[snum, :chart_size, :chart_size, :]
        # Tag position 0 is [CLS]; positions 1..len(sentence) belong to the words.
        predicted_tags = [TAG_VOCAB[idx] for idx in out_val_tags[snum, 1:chart_size]]
        tagged_sentence = list(zip(predicted_tags, sentence))
        tree, _score = parser.decode_from_chart(tagged_sentence, tf_chart)
        batch_trees.append(tree)
    test_predicted.extend(predicted.convert() for predicted in batch_trees)

100%|██████████| 28/28 [00:02<00:00, 11.31it/s]


In [16]:
# Score predictions with EVALB ('EVALB/' is presumably the EVALB install directory). The gold list
# is truncated to the number of predictions so the two sequences line up. The object-repr lines
# printed below come from inside evaluate.evalb, the only call in this cell.
test_fscore = evaluate.evalb('EVALB/', test_treebank[:len(test_predicted)], test_predicted)

<trees_newline.InternalTreebankNode object at 0x7f95a004a978> <class 'trees_newline.InternalTreebankNode'> <trees_newline.InternalTreebankNode object at 0x7f952c3d8a58>
<trees_newline.InternalTreebankNode object at 0x7f95a006d128> <class 'trees_newline.InternalTreebankNode'> <trees_newline.InternalTreebankNode object at 0x7f952c3db668>
<trees_newline.InternalTreebankNode object at 0x7f95a00084a8> <class 'trees_newline.InternalTreebankNode'> <trees_newline.InternalTreebankNode object at 0x7f952c3de198>
<trees_newline.InternalTreebankNode object at 0x7f95a001ac88> <class 'trees_newline.InternalTreebankNode'> <trees_newline.InternalTreebankNode object at 0x7f952c3decf8>
<trees_newline.InternalTreebankNode object at 0x7f958f708588> <class 'trees_newline.InternalTreebankNode'> <trees_newline.InternalTreebankNode object at 0x7f952c3e1908>
<trees_newline.InternalTreebankNode object at 0x7f958de87e48> <class 'trees_newline.InternalTreebankNode'> <trees_newline.InternalTreebankNode object at 0x

In [17]:
# Display the EVALB summary (recall, precision, F-score, complete match, tagging accuracy).
str(test_fscore)

'(Recall=74.89, Precision=78.79, FScore=76.79, CompleteMatch=9.01, TaggingAccuracy=91.17)'

In [18]:
# Graph node names of the exported inputs/outputs: tensor names minus their ":<index>" suffix.
input_node_names = [tensor.name.split(':')[0] for tensor in (the_inp_tokens, the_inp_mask)]
output_node_names = [tensor.name.split(':')[0] for tensor in (the_out_chart, the_out_tags)]

In [19]:
# Freeze: variables needed for the outputs become constants holding their current session values.
graph_def = tf.compat.v1.graph_util.convert_variables_to_constants(
    sess,
    sess.graph_def,
    output_node_names,
)
from tensorflow.tools.graph_transforms import TransformGraph

# Graph Transform Tool passes, applied in order:
# - drop nodes not needed for the outputs;
# - splice out Identity/CheckNumerics ops;
# - pre-compute constant subgraphs;
# - fold batch-norm patterns;
# - round weights to 128 levels (per the tool's docs, this helps compression while keeping
#   float weights).
graph_def = TransformGraph(
    graph_def,
    input_node_names,
    output_node_names,
    [
        'strip_unused_nodes()',
        'remove_nodes(op=Identity, op=CheckNumerics)',
        'fold_constants()',
        'fold_old_batch_norms',
        'fold_batch_norms',
        'round_weights(num_steps=128)',
    ],
)

Instructions for updating:
Use `tf.compat.v1.graph_util.convert_variables_to_constants`
Instructions for updating:
Use `tf.compat.v1.graph_util.extract_sub_graph`
INFO:tensorflow:Froze 197 variables.
INFO:tensorflow:Converted 197 variables to const ops.


In [20]:
import os

# Create the export directory if needed; open() would otherwise fail with FileNotFoundError.
os.makedirs('export', exist_ok=True)
with open('export/model-tiny.pb', 'wb') as f:
    f.write(graph_def.SerializeToString())

In [21]:
import json

# Save the index-ordered vocabularies (entry i is the symbol with id i) next to the graph export.
with open('vocab-tiny.json', 'w') as vocab_file:
    json.dump(dict(label=LABEL_VOCAB, tag=TAG_VOCAB), vocab_file)