In [1]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '2'

In [2]:
import sys

SOURCE_DIR = os.path.dirname(os.path.abspath(__name__))
sys.path.insert(0, f"{SOURCE_DIR}/src")
sys.path.append(SOURCE_DIR)

In [3]:
import tensorflow as tf

In [4]:
from transformers import XLNetTokenizer
tokenizer = XLNetTokenizer.from_pretrained(
    'huseinzol05/xlnet-base-bahasa-cased', do_lower_case = False
)

In [5]:
import json

with open('vocab-xlnet-base.json') as fopen:
    data = json.load(fopen)
    
LABEL_VOCAB = data['label']
TAG_VOCAB = data['tag']

In [6]:
with tf.compat.v1.gfile.GFile('export/xlnet-base.pb.quantized', 'rb') as f:
    graph_def = tf.compat.v1.GraphDef()
    graph_def.ParseFromString(f.read())
    with tf.compat.v1.Graph().as_default() as graph:
        tf.compat.v1.import_graph_def(graph_def)

In [7]:
input_ids = graph.get_tensor_by_name('import/input_ids:0')
word_end_mask = graph.get_tensor_by_name('import/word_end_mask:0')
charts = graph.get_tensor_by_name('import/charts:0')
tags = graph.get_tensor_by_name('import/tags:0')
sess = tf.compat.v1.InteractiveSession(graph = graph)

In [8]:
BERT_MAX_LEN = 512
import numpy as np
from parse_nk import BERT_TOKEN_MAPPING

def make_feed_dict_bert(sentences):
    all_input_ids = np.zeros((len(sentences), BERT_MAX_LEN), dtype=int)
    all_word_end_mask = np.zeros((len(sentences), BERT_MAX_LEN), dtype=int)

    subword_max_len = 0
    for snum, sentence in enumerate(sentences):
        tokens = []
        word_end_mask = []

        cleaned_words = []
        for word in sentence:
            word = BERT_TOKEN_MAPPING.get(word, word)
            if word == "n't" and cleaned_words:
                cleaned_words[-1] = cleaned_words[-1] + "n"
                word = "'t"
            cleaned_words.append(word)

        for word in cleaned_words:
            word_tokens = tokenizer.tokenize(word)
            for _ in range(len(word_tokens)):
                word_end_mask.append(0)
            word_end_mask[-1] = 1
            tokens.extend(word_tokens)
        tokens.append("<sep>")
        word_end_mask.append(1)
        tokens.append("<cls>")
        word_end_mask.append(1)

        input_ids = tokenizer.convert_tokens_to_ids(tokens)
        input_mask = [1] * len(input_ids)

        subword_max_len = max(subword_max_len, len(input_ids))

        all_input_ids[snum, :len(input_ids)] = input_ids
        all_word_end_mask[snum, :len(word_end_mask)] = word_end_mask

    all_input_ids = all_input_ids[:, :subword_max_len]
    all_word_end_mask = all_word_end_mask[:, :subword_max_len]
    return all_input_ids, all_word_end_mask

In [9]:
s = 'Dr Mahathir menasihati mereka supaya berhenti berehat dan tidur sebentar sekiranya mengantuk ketika memandu.'.split()
sentences = [s]
i, m = make_feed_dict_bert(sentences)
i, m

(array([[  383,  1096, 21767,    88,   757,  1606, 15738,    24,   198,
          4049,  2479,  7529,   271,  7644,     9,     4,     3]]),
 array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1]]))

In [10]:
charts_val, tags_val = sess.run((charts, tags), {input_ids: i, word_end_mask: m})
charts_val, tags_val

(array([[[[ 0.        , -6.4940166 , -4.012821  , ..., -2.4423234 ,
           -3.2888124 , -2.3206522 ],
          [ 0.        , -3.027123  , -4.226623  , ..., -2.8357825 ,
           -1.8898286 , -1.4015231 ],
          [ 0.        , -4.4751773 , -4.0391564 , ..., -2.0557482 ,
           -2.5116968 , -2.3222225 ],
          ...,
          [ 0.        , -3.0751634 , -5.0968566 , ..., -2.8086288 ,
           -2.4999511 , -2.1263318 ],
          [ 0.        , -0.21056579, -3.830736  , ..., -2.8927965 ,
           -3.2552593 , -2.4863706 ],
          [ 0.        ,  0.        ,  0.        , ...,  0.        ,
            0.        ,  0.        ]],
 
         [[ 0.        , -2.7311478 , -3.266908  , ..., -2.5407348 ,
           -2.7771876 , -3.4675407 ],
          [ 0.        , -6.4940166 , -4.012821  , ..., -2.4423234 ,
           -3.2888124 , -2.3206522 ],
          [ 0.        , -3.5888338 , -3.3573806 , ..., -1.5800554 ,
           -3.0464988 , -3.0329437 ],
          ...,
          [ 0

In [11]:
for snum, sentence in enumerate(sentences):
    chart_size = len(sentence) + 1
    chart = charts_val[snum,:chart_size,:chart_size,:]

In [12]:
# !wget https://raw.githubusercontent.com/michaeljohns2/self-attentive-parser/michaeljohns2-support-tf2-patch/benepar/chart_decoder.pyx

In [13]:
import chart_decoder_py

In [14]:
chart_decoder_py.decode(chart)

(20.88371,
 array([ 0,  0,  0,  1,  2,  2,  2,  3,  3,  4,  4,  5,  5,  6,  6,  7,  7,
         8,  8,  9,  9, 10, 10, 11, 11, 12, 13]),
 array([14,  2,  1,  2, 14, 13,  3, 13,  4, 13,  5, 13,  6, 13,  7, 13,  8,
        13,  9, 13, 10, 13, 11, 13, 12, 13, 14]),
 array([ 1,  4,  0,  4,  0,  5,  0,  0,  0,  7,  0,  5,  0,  5,  5,  0,  0,
         5,  0,  0, 13,  7,  0,  1,  4,  0,  0]))

In [15]:
import nltk
from nltk import Tree

In [16]:
PTB_TOKEN_ESCAPE = {u"(": u"-LRB-",
    u")": u"-RRB-",
    u"{": u"-LCB-",
    u"}": u"-RCB-",
    u"[": u"-LSB-",
    u"]": u"-RSB-"}


def make_nltk_tree(sentence, tags, score, p_i, p_j, p_label):

    # Python 2 doesn't support "nonlocal", so wrap idx in a list
    idx_cell = [-1]
    def make_tree():
        idx_cell[0] += 1
        idx = idx_cell[0]
        i, j, label_idx = p_i[idx], p_j[idx], p_label[idx]
        label = LABEL_VOCAB[label_idx]
        if (i + 1) >= j:
            word = sentence[i]
            tag = TAG_VOCAB[tags[i]]
            tag = PTB_TOKEN_ESCAPE.get(tag, tag)
            word = PTB_TOKEN_ESCAPE.get(word, word)
            tree = Tree(tag, [word])
            for sublabel in label[::-1]:
                tree = Tree(sublabel, [tree])
            return [tree]
        else:
            left_trees = make_tree()
            right_trees = make_tree()
            children = left_trees + right_trees
            if label:
                tree = Tree(label[-1], children)
                for sublabel in reversed(label[:-1]):
                    tree = Tree(sublabel, [tree])
                return [tree]
            else:
                return children

    tree = make_tree()[0]
    tree.score = score
    return tree

In [17]:
tree = make_nltk_tree(s, tags_val[0], *chart_decoder_py.decode(chart))
print(str(tree))

(S
  (NP-SBJ (<START> Dr) (NP-SBJ (CC Mahathir)))
  (VP
    (NNP menasihati)
    (VB mereka)
    (SBAR
      (PRP supaya)
      (VP
        (IN berhenti)
        (VP
          (VP (VB berehat))
          (VB dan)
          (VP
            (CC tidur)
            (ADVP (VB sebentar))
            (SBAR
              (RB sekiranya)
              (S (NP-SBJ (IN mengantuk)) (NN ketika))))))))
  (NN memandu.))


In [18]:
def make_str_tree(sentence, tags, score, p_i, p_j, p_label):
    idx_cell = [-1]
    def make_str():
        idx_cell[0] += 1
        idx = idx_cell[0]
        i, j, label_idx = p_i[idx], p_j[idx], p_label[idx]
        label = LABEL_VOCAB[label_idx]
        if (i + 1) >= j:
            word = sentence[i]
            tag = TAG_VOCAB[tags[i]]
            tag = PTB_TOKEN_ESCAPE.get(tag, tag)
            word = PTB_TOKEN_ESCAPE.get(word, word)
            s = u"({} {})".format(tag, word)
        else:
            children = []
            while ((idx_cell[0] + 1) < len(p_i)
                and i <= p_i[idx_cell[0] + 1]
                and p_j[idx_cell[0] + 1] <= j):
                children.append(make_str())

            s = u" ".join(children)
            
        for sublabel in reversed(label):
            s = u"({} {})".format(sublabel, s)
        return s
    return make_str()

In [19]:
make_str_tree(s, tags_val[0], *chart_decoder_py.decode(chart))

'(S (NP-SBJ (<START> Dr) (NP-SBJ (CC Mahathir))) (VP (NNP menasihati) (VB mereka) (SBAR (PRP supaya) (VP (IN berhenti) (VP (VP (VB berehat)) (VB dan) (VP (CC tidur) (ADVP (VB sebentar)) (SBAR (RB sekiranya) (S (NP-SBJ (IN mengantuk)) (NN ketika)))))))) (NN memandu.))'