In [1]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '2'

In [2]:
import sys
sys.path.insert(0, "/home/husein/parsing/self-attentive-parser/src")
sys.path.append("/home/husein/parsing/self-attentive-parser")

In [3]:
import tensorflow as tf
from transformers import AlbertTokenizer

In [4]:
from transformers import AlbertTokenizer
tokenizer = AlbertTokenizer.from_pretrained(
    'huseinzol05/albert-base-bahasa-cased',
    do_lower_case = False,
)

In [5]:
import json

with open('vocab-albert-base.json') as fopen:
    data = json.load(fopen)
    
LABEL_VOCAB = data['label']
TAG_VOCAB = data['tag']

In [8]:
with tf.compat.v1.gfile.GFile('export/albert-base.pb', 'rb') as f:
    graph_def = tf.compat.v1.GraphDef()
    graph_def.ParseFromString(f.read())
    with tf.compat.v1.Graph().as_default() as graph:
        tf.compat.v1.import_graph_def(graph_def)

In [9]:
input_ids = graph.get_tensor_by_name('import/input_ids:0')
word_end_mask = graph.get_tensor_by_name('import/word_end_mask:0')
charts = graph.get_tensor_by_name('import/charts:0')
tags = graph.get_tensor_by_name('import/tags:0')
sess = tf.compat.v1.InteractiveSession(graph = graph)

In [10]:
BERT_MAX_LEN = 512
import numpy as np
from parse_nk import BERT_TOKEN_MAPPING

def make_feed_dict_bert(sentences):
    all_input_ids = np.zeros((len(sentences), BERT_MAX_LEN), dtype=int)
    all_word_end_mask = np.zeros((len(sentences), BERT_MAX_LEN), dtype=int)
    

    subword_max_len = 0
    for snum, sentence in enumerate(sentences):
        tokens = []
        word_end_mask = []

        tokens.append(u"[CLS]")
        word_end_mask.append(1)

        cleaned_words = []
        for word in sentence:
            word = BERT_TOKEN_MAPPING.get(word, word)
            # BERT is pre-trained with a tokenizer that doesn't split off
            # n't as its own token
            if word == u"n't" and cleaned_words:
                cleaned_words[-1] = cleaned_words[-1] + u"n"
                word = u"'t"
            cleaned_words.append(word)

        for word in cleaned_words:
            word_tokens = tokenizer.tokenize(word)
            if not word_tokens:
                # The tokenizer used in conjunction with the parser may not
                # align with BERT; in particular spaCy will create separate
                # tokens for whitespace when there is more than one space in
                # a row, and will sometimes separate out characters of
                # unicode category Mn (which BERT strips when do_lower_case
                # is enabled). Substituting UNK is not strictly correct, but
                # it's better than failing to return a valid parse.
                word_tokens = ["[UNK]"]
            for _ in range(len(word_tokens)):
                word_end_mask.append(0)
            word_end_mask[-1] = 1
            tokens.extend(word_tokens)
        tokens.append(u"[SEP]")
        word_end_mask.append(1)

        input_ids = tokenizer.convert_tokens_to_ids(tokens)

        subword_max_len = max(subword_max_len, len(input_ids))

        all_input_ids[snum, :len(input_ids)] = input_ids
        all_word_end_mask[snum, :len(word_end_mask)] = word_end_mask

    all_input_ids = all_input_ids[:, :subword_max_len]
    all_word_end_mask = all_word_end_mask[:, :subword_max_len]
    return all_input_ids, all_word_end_mask

In [11]:
s = 'Saya sedang membaca buku tentang Perlembagaan'.split()
sentences = [s]
i, m = make_feed_dict_bert(sentences)
i, m

(array([[   2,  280,  457, 1520,  597,  454, 3794,    3]]),
 array([[1, 1, 1, 1, 1, 1, 1, 1]]))

In [12]:
charts_val, tags_val = sess.run((charts, tags), {input_ids: i, word_end_mask: m})
charts_val, tags_val

(array([[[[ 0.        , -4.6400285 , -3.263862  , ..., -2.365652  ,
           -2.3575218 , -2.1976795 ],
          [ 0.        , -1.9336585 , -2.5753968 , ..., -1.7595727 ,
           -1.7113284 , -1.9766747 ],
          [ 0.        , -1.8684319 , -2.5347548 , ..., -2.1521864 ,
           -1.9502559 , -2.4117303 ],
          ...,
          [ 0.        , -2.4688456 , -5.0541735 , ..., -2.1501265 ,
           -2.7817178 , -2.045808  ],
          [ 0.        ,  2.0402954 , -1.3372381 , ..., -2.0561574 ,
           -2.2189684 , -2.2715192 ],
          [ 0.        ,  0.        ,  0.        , ...,  0.        ,
            0.        ,  0.        ]],
 
         [[ 0.        , -3.761474  , -2.1691635 , ..., -2.5855842 ,
           -2.1737652 , -2.256596  ],
          [ 0.        , -4.6400285 , -3.263862  , ..., -2.365652  ,
           -2.3575218 , -2.1976795 ],
          [ 0.        , -3.0972695 , -1.830802  , ..., -2.2955809 ,
           -2.1975942 , -2.3570776 ],
          ...,
          [ 0

In [13]:
for snum, sentence in enumerate(sentences):
    chart_size = len(sentence) + 1
    chart = charts_val[snum,:chart_size,:chart_size,:]

In [14]:
# !wget https://raw.githubusercontent.com/michaeljohns2/self-attentive-parser/michaeljohns2-support-tf2-patch/benepar/chart_decoder.pyx

In [15]:
import chart_decoder_py

In [16]:
chart_decoder_py.decode(chart)

(17.99921,
 array([0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5]),
 array([6, 1, 6, 2, 6, 3, 6, 4, 6, 5, 6]),
 array([1, 4, 5, 0, 5, 0, 3, 3, 2, 0, 3]))

In [17]:
import nltk
from nltk import Tree

In [18]:
PTB_TOKEN_ESCAPE = {u"(": u"-LRB-",
    u")": u"-RRB-",
    u"{": u"-LCB-",
    u"}": u"-RCB-",
    u"[": u"-LSB-",
    u"]": u"-RSB-"}


def make_nltk_tree(sentence, tags, score, p_i, p_j, p_label):

    # Python 2 doesn't support "nonlocal", so wrap idx in a list
    idx_cell = [-1]
    def make_tree():
        idx_cell[0] += 1
        idx = idx_cell[0]
        i, j, label_idx = p_i[idx], p_j[idx], p_label[idx]
        label = LABEL_VOCAB[label_idx]
        if (i + 1) >= j:
            word = sentence[i]
            tag = TAG_VOCAB[tags[i]]
            tag = PTB_TOKEN_ESCAPE.get(tag, tag)
            word = PTB_TOKEN_ESCAPE.get(word, word)
            tree = Tree(tag, [word])
            for sublabel in label[::-1]:
                tree = Tree(sublabel, [tree])
            return [tree]
        else:
            left_trees = make_tree()
            right_trees = make_tree()
            children = left_trees + right_trees
            if label:
                tree = Tree(label[-1], children)
                for sublabel in reversed(label[:-1]):
                    tree = Tree(sublabel, [tree])
                return [tree]
            else:
                return children

    tree = make_tree()[0]
    tree.score = score
    return tree

In [19]:
tree = make_nltk_tree(s, tags_val[0], *chart_decoder_py.decode(chart))
print(str(tree))

(S
  (NP-SBJ (<START> Saya))
  (VP
    (PRP sedang)
    (VP
      (MD membaca)
      (NP (NP (VB buku)) (PP (NN tentang) (NP (IN Perlembagaan)))))))


In [20]:
def make_str_tree(sentence, tags, score, p_i, p_j, p_label):
    idx_cell = [-1]
    def make_str():
        idx_cell[0] += 1
        idx = idx_cell[0]
        i, j, label_idx = p_i[idx], p_j[idx], p_label[idx]
        label = LABEL_VOCAB[label_idx]
        if (i + 1) >= j:
            word = sentence[i]
            tag = TAG_VOCAB[tags[i]]
            tag = PTB_TOKEN_ESCAPE.get(tag, tag)
            word = PTB_TOKEN_ESCAPE.get(word, word)
            s = u"({} {})".format(tag, word)
        else:
            children = []
            while ((idx_cell[0] + 1) < len(p_i)
                and i <= p_i[idx_cell[0] + 1]
                and p_j[idx_cell[0] + 1] <= j):
                children.append(make_str())

            s = u" ".join(children)
            
        for sublabel in reversed(label):
            s = u"({} {})".format(sublabel, s)
        return s
    return make_str()

In [22]:
make_str_tree(s, tags_val[0], *chart_decoder_py.decode(chart))

'(S (NP-SBJ (<START> Saya)) (VP (PRP sedang) (VP (MD membaca) (NP (NP (VB buku)) (PP (NN tentang) (NP (IN Perlembagaan)))))))'