In [1]:
from src.models.const_tree import C_Tree
from src.encs.constituent import *
from src.utils.constants import *
from src.models.linearized_tree import LinearizedTree
from nltk.tree import Tree

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Example sentence as a bracketed constituency tree
tree_string = "(S (INTJ (RB No)) (, ,) (NP (PRP it)) (VP (VBD was) (RB n't) (NP (NNP Black) (NNP Monday))) (. .))"
gold_tree = C_Tree.from_string(tree_string)

# Round-trip through str() so nltk can parse the bracketing and draw it as ASCII art
nltk_t = Tree.fromstring(str(gold_tree))
nltk_t.pretty_print()

# Encode it as binary
binary_tree = C_Tree.to_binary(gold_tree)
# Each transformation is applied to the result of the previous one
binary_tree = binary_tree.collapse_unary()
binary_tree = binary_tree.remove_preterminals()

# Called for its effect on binary_tree (the return value is not used);
# the bracketing printed next is headed by -ROOT-
binary_tree.add_root_node()
print(str(binary_tree))
# nltk_t is rebound: from here on it holds the binarized tree, not the gold one
nltk_t = Tree.fromstring(str(binary_tree))
nltk_t.pretty_print()

              S                          
  ____________|________________________   
 |    |   |       VP                   | 
 |    |   |    ___|_________           |  
INTJ  |   NP  |   |         NP         | 
 |    |   |   |   |     ____|____      |  
 RB   ,  PRP VBD  RB  NNP       NNP    . 
 |    |   |   |   |    |         |     |  
 No   ,   it was n't Black     Monday  . 

(-ROOT- (S No (S* , (S** it (S*** (VP was (VP* n't (NP Black Monday))) .)))))
    -ROOT-                                      
      |                                          
      S                                         
  ____|_____                                     
 |          S*                                  
 |     _____|_______                             
 |    |            S**                          
 |    |      _______|___                         
 |    |     |          S***                     
 |    |     |    _______|____                    
 |    |     |   |            VP             

In [21]:
from src.models.const_label import C_Label
from src.models.linearized_tree import LinearizedTree

import re

# Integer id for each two-letter direction label ("l"/"r" followed by "L"/"R").
# NOTE(review): not referenced by any code visible in this part of the notebook.
directions_dir = {"lL":0,"lR":1,"rL":2,"rR":3}
def get_child_directions(t):
    '''
    Collects one direction string per terminal found below 't',
    visiting children in order and descending depth-first.

    Each string is built from up to two letters:
      - first letter: "l" if the terminal is a right child,
        "r" if it is a left child
      - second letter: "L" if the terminal's parent is a right child,
        "R" if it is a left child
    A letter is left out when neither predicate holds for that node.

    Returns a (possibly empty) list of strings.
    '''
    directions = []
    for child in t.children:
        if child.is_terminal():
            # Letter describing the terminal itself
            if child.is_right_child():
                own_side = "l"
            elif child.is_left_child():
                own_side = "r"
            else:
                own_side = ""

            # Letter describing the node the terminal hangs from
            if child.parent.is_right_child():
                parent_side = "L"
            elif child.parent.is_left_child():
                parent_side = "R"
            else:
                parent_side = ""

            directions.append(own_side + parent_side)

        # Always recurse; a terminal simply contributes nothing more
        directions.extend(get_child_directions(child))

    return directions

def encode(constituent_tree):
    '''
    Linearizes 'constituent_tree' into a LinearizedTree with one row
    (word, postag, feats, C_Label) per pair of consecutive leaf paths.

    The label of each row combines:
      - the direction string computed by get_child_directions on the
        binarized tree after collapsing unaries, removing preterminals
        and wrapping it under C_ROOT_LABEL
      - the last node shared by the paths of the word and the next leaf
      - the leaf unary chain, if the POS element of the path holds one

    Returns the filled LinearizedTree; its 'n' attribute is set to the
    largest feature-list length among the rows (0 if there are no rows).
    '''
    lc_tree = LinearizedTree.empty_tree()

    # Compute the Binary Tree and the arrows
    binary_tree = C_Tree.to_binary(constituent_tree)

    # Each step must consume the result of the previous one. Before, every
    # assignment started again from 'binary_tree', so the collapsed tree was
    # thrown away and the directions were read from the untransformed tree.
    binary_tree_collapsed = binary_tree.collapse_unary()
    binary_tree_collapsed = binary_tree_collapsed.remove_preterminals()
    binary_tree_collapsed = C_Tree(C_ROOT_LABEL, binary_tree_collapsed)
    child_dirs = get_child_directions(binary_tree_collapsed)

    # Compute the number of commons
    # NOTE(review): the paths come from 'binary_tree' itself, which is expected
    # to still contain the POS level (path[-2] is read as the POS tag below).
    leaf_paths = binary_tree.path_to_leaves(collapse_unary=True, unary_joiner="+")

    for i in range(0, len(leaf_paths)-1):
        path_a = leaf_paths[i]
        path_b = leaf_paths[i+1]

        last_common = ""
        for a, b in zip(path_a, path_b):
            if a != b:
                # Remove the digits and additional feats in the last common node
                last_common = re.sub(r'[0-9]+', '', last_common)
                last_common = last_common.split("##")[0]

                # Get word and POS tag
                word = path_a[-1]
                postag = path_a[-2]

                # Build the Leaf Unary Chain: everything before the last "+"
                # is the chain, the last element is the actual POS tag
                unary_chain = None
                leaf_unary_chain = postag.split("+")
                if len(leaf_unary_chain) > 1:
                    unary_chain = "+".join(
                        element.split("##")[0] for element in leaf_unary_chain[:-1]
                    )
                    postag = leaf_unary_chain[-1]

                # Clean the POS Tag and extract additional features
                # ("TAG##f1|f2" -> tag "TAG", feats ["f1", "f2"])
                postag_split = postag.split("##")
                postag = re.sub(r'[0-9]+', '', postag_split[0])
                feats = postag_split[1].split("|") if len(postag_split) > 1 else [None]

                direction = child_dirs[i]
                c_label = C_Label(direction, last_common, unary_chain, C_TETRA_ENCODING, "_", "+")

                # Append the data
                lc_tree.add_row(word, postag, feats, c_label)

                # Only the first point of divergence matters for this pair
                break
            last_common = a

    # n = max number of features of the tree
    # default=0 keeps this from raising ValueError when no row was added
    lc_tree.n = max((len(f) for f in lc_tree.additional_feats), default=0)
    return lc_tree

# Example sentence in bracketed notation, parsed into a C_Tree
tree_string = "(S (INTJ (RB No)) (, ,) (NP (PRP it)) (VP (VBD was) (RB n't) (NP (NNP Black) (NNP Monday))) (. .))"
gold_tree = C_Tree.from_string(tree_string)

# Linearize the tree; the bare 'lc' as last expression makes the notebook display it
lc = encode(gold_tree)
lc

-BOS-	-BOS-	-BOS-
No	RB	rR_S_INTJ
,	,	rL_S*
it	PRP	rR_S**_NP
was	VBD	rR_VP
n't	RB	rL_VP*
Black	NNP	rL_NP
Monday	NNP	lL_S***
.	.	lL_S
-EOS-	-EOS-	-EOS-

In [None]:
import copy

def pop_first(buffer):
    '''
    Removes the first element of 'buffer' (in place) and returns it
    together with the same, now shorter, buffer.

    Returns: (leaf, buffer)
    Raises:  IndexError if 'buffer' is empty
    '''
    if len(buffer) == 0:
        # Raising a plain string is itself a TypeError in Python 3,
        # so raise a real exception that carries the message
        raise IndexError("Error: Empty buffer")
    leaf = buffer.pop(0)
    return leaf, buffer

def push_last(stack, leaf):
    '''
    Puts 'leaf' on top of 'stack' (in place) and hands the same
    stack object back to the caller.
    '''
    stack.extend((leaf,))
    return stack

def pop_last(stack):
    '''
    Removes the top (last) element of 'stack' (in place) and returns it
    together with the same, now shorter, stack.

    Returns: (leaf, stack)
    Raises:  IndexError if 'stack' is empty
    '''
    if len(stack) == 0:
        # Raising a plain string is itself a TypeError in Python 3,
        # so raise a real exception that carries the message
        raise IndexError("Error: Empty stack")
    leaf = stack.pop()
    return leaf, stack

def combine(tree, new_child):
    '''
    Plugs 'new_child' into the place of a C_NONE_LABEL child of 'tree'.

    Starting at 'tree', the right child is followed until a node
    reporting has_none_child() is reached; there, the first of its
    two children labelled C_NONE_LABEL is overwritten.
    A plain string is wrapped in a C_Tree first.

    Returns the (modified in place) 'tree'.
    '''
    # trees should have only 2 child nodes
    if type(new_child) is str:
        new_child = C_Tree(new_child)

    node = tree
    while not node.has_none_child():
        node = node.r_child()

    print("filling...")
    # Left slot has priority; the right one is only inspected if the left is taken
    for slot in (0, 1):
        if node.children[slot].label == C_NONE_LABEL:
            node.children[slot] = new_child
            break
    return tree

def make_node(tag, lchild, rchild):
    '''
    Builds a C_Tree labelled 'tag' with 'lchild' and 'rchild' as its
    two children. Any child that is not already a C_Tree is wrapped
    in one first.

    Returns the new C_Tree.
    '''
    if type(lchild) is not C_Tree:
        lchild = C_Tree(lchild)
    if type(rchild) is not C_Tree:
        # Wrap the right child itself; assigning to lchild here would
        # overwrite the left child and leave the right one unwrapped
        rchild = C_Tree(rchild)
    return C_Tree(tag, [lchild, rchild])

# Decoding according to paper
# base_tree = C_Tree(C_ROOT_LABEL)
# base_tree.add_child([C_Tree.empty_tree(),C_Tree.empty_tree()])
# 'stack' holds the partial trees built so far, 'buffer' the words still
# to be consumed (deep-copied so that lc.words itself is not emptied)
stack = []
buffer = copy.deepcopy(lc.words)
tree = None
for word, postag, feats, label in lc.iterrows():
    # a is indexed as a two-character direction string: a1 drives the leaf
    # action, a2 the node action; t is the label used for new nodes.
    # NOTE(review): presumably n_commons holds the direction that encode()
    # passes as first C_Label argument - confirm against C_Label.
    # uc (the unary chain) is read but not used below.
    a, t, uc = label.n_commons, label.last_common, label.unary_chain
    a1, a2 = a[0], a[1]
    print("W = [",word,"]: ",a1,a2,t)
    if len(stack)>=1:
        print("S = [",stack[-1],"]")
    else:
        print("S = [ ]")

    # "r": move the next word from the buffer onto the stack
    if a1 == "r":
        leaf, buffer = pop_first(buffer)
        stack = push_last(stack, leaf)
    
    # "l": use the next word to fill the open slot of the tree on top of the stack
    if a1 == "l":
        leaf, buffer = pop_first(buffer)
        stack[-1] = combine(stack[-1], leaf)


    # whenever the buffer is empty, we STOP. We should perform
    # (2 * |w| - 1) actions
    
    if len(buffer)==0:
        break


    # "R": the stack top becomes the left child of a new node labelled t,
    # whose right child is an empty placeholder tree
    if a2 == "R":
        print(stack, t)
        stack[-1] = make_node(t, stack[-1], C_Tree.empty_tree())
        
    # "L": pop the stack top, wrap it the same way and plug the new node
    # into the open slot of the tree that is now on top
    if a2 == "L":
        tree, stack = pop_last(stack)
        print("Applying make node from L to ", tree)
        tree = make_node(t, tree, C_Tree.empty_tree())
        stack[-1] = combine(stack[-1], tree)

    # Draw the current stack top after every step
    nltk_t = Tree.fromstring(str(stack[-1]))
    nltk_t.pretty_print()

print("done")
# The bottom of the stack is drawn as the final result
nltk_t = Tree.fromstring(str(stack[0]))
nltk_t.pretty_print()


W = [ No ]:  r R S
S = [ ]
['No'] S
     S        
  ___|____     
 No     -NONE-

W = [ , ]:  r L S*
S = [ (S No -NONE-) ]
Applying make node from L to  ,
filling...
     S            
  ___|___          
 |       S*       
 |    ___|____     
 No  ,      -NONE-

W = [ it ]:  r L S**
S = [ (S No (S* , -NONE-)) ]
Applying make node from L to  it
filling...
         S                
  _______|___              
 |           S*           
 |    _______|___          
 |   |          S**       
 |   |        ___|____     
 No  ,       it     -NONE-

W = [ was ]:  r R VP
S = [ (S No (S* , (S** it -NONE-))) ]
[(S No (S* , (S** it -NONE-))), 'was'] VP
     VP       
  ___|____     
was     -NONE-

W = [ n't ]:  r L VP*
S = [ (VP was -NONE-) ]
Applying make node from L to  n't
filling...
     VP           
  ___|___          
 |      VP*       
 |    ___|____     
was n't     -NONE-

W = [ Black ]:  r L NP
S = [ (VP was (VP* n't -NONE-)) ]
Applying make node from L to  Black
filling...
       