In [1]:
from nltk.tree import Tree

In [2]:
tree_string = '(S (NP (DT The) (NNS owls)) (VP (VBP are) (RB not) (SBAR (WHNP (WP what)) (S (NP (PRP they)) (VP (VBP seem))))) (PUNCT .))'
tree = Tree.fromstring(tree_string)
tree.pretty_print()

                  S                          
      ____________|_______________________    
     |                VP                  |  
     |         _______|____               |   
     |        |   |       SBAR            |  
     |        |   |    ____|____          |   
     |        |   |   |         S         |  
     |        |   |   |     ____|___      |   
     NP       |   |  WHNP  NP       VP    |  
  ___|___     |   |   |    |        |     |   
 DT     NNS  VBP  RB  WP  PRP      VBP  PUNCT
 |       |    |   |   |    |        |     |   
The     owls are not what they     seem   .  



Naive absolute encodes the number of commons between words w<sub>i</sub> and w<sub>i+1</sub>. This number is in the label associated to w<sub>i</sub>.

In [3]:
from src.encs.enc_const.naive_absolute import C_NaiveAbsoluteEncoding
from src.models.const_tree import ConstituentTree

encoder = C_NaiveAbsoluteEncoding(separator="_", unary_joiner="+")
constituent_tree = ConstituentTree.from_string(tree_string)
w, p, l, f = encoder.encode(constituent_tree)

for i in range(len(w)):
    n, lc = l[i].n_commons, l[i].last_common
    print("Word '{}' has {} common ancestors with his next word. The last common ancestor is {}".format(w[i], n, lc))


Word 'The' has 2 common ancestors with his next word. The last common ancestor is NP
Word 'owls' has 1 common ancestors with his next word. The last common ancestor is S
Word 'are' has 2 common ancestors with his next word. The last common ancestor is VP
Word 'not' has 2 common ancestors with his next word. The last common ancestor is VP
Word 'what' has 3 common ancestors with his next word. The last common ancestor is SBAR
Word 'they' has 4 common ancestors with his next word. The last common ancestor is S
Word 'seem' has 1 common ancestors with his next word. The last common ancestor is S
Word '.' has 1 common ancestors with his next word. The last common ancestor is S


We want to encode in the word w<sub>i</sub> the number of commons between word w<sub>i-1</sub> and w<sub>i</sub>.

```
Word 'The' has 1 common ancestors with his previous word. The last common ancestor is S.
Word 'owls' has 2 common ancestors with his previous word. The last common ancestor is NP.
Word 'are' has 1 common ancestors with his previous word. The last common ancestor is S.
Word 'not' has 2 common ancestors with his previous word. The last common ancestor is VP.
Word 'what' has 2 common ancestors with his previous word. The last common ancestor is VP.
Word 'they' has 3 common ancestors with his previous word. The last common ancestor is SBAR.
Word 'seem' has 4 common ancestors with his previous word. The last common ancestor is S.
Word '.' has 1 common ancestors with his previous word. The last common ancestor is S.
(...)
```


In [14]:
from src.encs.abstract_encoding import ACEncoding
from src.utils.constants import C_ABSOLUTE_ENCODING, C_ROOT_LABEL, C_CONFLICT_SEPARATOR, C_NONE_LABEL, C_DUMMY_START, C_DUMMY_END
from src.models.const_label import ConstituentLabel
from src.models.const_tree import ConstituentTree

import re

class C_NaiveIncremental(ACEncoding):
    def __init__(self, separator, unary_joiner):
        self.separator = separator
        self.unary_joiner = unary_joiner

    def get_unary_chain(self, postag):
        unary_chain = None
        leaf_unary_chain = postag.split(self.unary_joiner)

        if len(leaf_unary_chain)>1:
            unary_list = []
            for element in leaf_unary_chain[:-1]:
                unary_list.append(element.split("##")[0])

            unary_chain = self.unary_joiner.join(unary_list)
            postag = leaf_unary_chain[len(leaf_unary_chain)-1]
        
        return unary_chain, postag
    
    def get_features(self, node, feature_marker="##", feature_splitter="|"):
        postag_split = node.split(feature_marker)
        feats = None

        if len(postag_split) > 1:
            postag = re.sub(r'[0-9]+', '', postag_split[0])
            feats = postag_split[1].split(feature_splitter)
        else:
            postag = re.sub(r'[0-9]+', '', node)
        return postag, feats
    
    def clean_last_common(self, node, feature_marker="##"):
        node = re.sub(r'[0-9]+', '', node)
        last_common = node.split(feature_marker)[0]
        return last_common

    def encode(self, constituent_tree):
        constituent_tree.reverse_tree()
        leaf_paths = constituent_tree.path_to_leaves(collapse_unary=True, unary_joiner=self.unary_joiner, dummy=C_DUMMY_END)
        labels=[]
        words=[]
        postags=[]
        additional_feats=[]

        for i in range(1, len(leaf_paths)):
            path_a = leaf_paths[i-1]
            path_b = leaf_paths[i]
            
            last_common = ""
            n_commons   = 0

            for a,b in zip(path_a, path_b):
                if (a!=b):
                    # Remove the digits and aditional feats in the last common node
                    last_common = self.clean_last_common(last_common)

                    # Get word and POS tag
                    word   = path_a[-1]
                    postag = path_a[-2]
                    
                    # Build the Leaf Unary Chain
                    unary_chain, postag = self.get_unary_chain(postag)
                    
                    # Clean the POS Tag and extract additional features
                    postag, feats = self.get_features(postag)

                    # Append the data
                    labels.append(ConstituentLabel(n_commons, last_common, unary_chain, C_ABSOLUTE_ENCODING, self.separator, self.unary_joiner))
                    words.append(word)
                    postags.append(postag)
                    additional_feats.append(feats)

                    break
                
                # Store Last Common and increase n_commons 
                # Note: When increasing n_commons use the number from split the collapsed chains
                n_commons  += len(a.split(self.unary_joiner))
                last_common = a
        
        # reverse and return
        words.reverse(); postags.reverse(); labels.reverse(); additional_feats.reverse()
        return words, postags, labels, additional_feats



    def decode(self, linearized_tree):
        # Check valid labels 
        if not linearized_tree:
            print("[*] Error while decoding: Null tree.")
            return

        # Create constituent tree
        tree = ConstituentTree(C_ROOT_LABEL, [])
        current_level = tree

        old_n_commons=0
        old_level=None

        linearized_tree.reverse()
        for row in linearized_tree:
            word, postag, label = row
            
            # Descend through the tree until reach the level indicated by last_common
            current_level = tree
            for level_index in range(label.n_commons):
                if (current_level.is_terminal()) or (level_index >= old_n_commons):
                    current_level.add_child(ConstituentTree(C_NONE_LABEL, []))
                
                current_level = current_level.r_child()

            # Split the Last Common field of the Label in case it has a Unary Chain Collapsed
            label.last_common = label.last_common.split(self.unary_joiner)

            if len(label.last_common) == 1:
                # If current level has no label yet, put the label
                # If current level has label but different than this one, set it as a conflict
                if (current_level.label == C_NONE_LABEL):
                    current_level.label = label.last_common[0]
                else:
                    current_level.label = current_level.label + C_CONFLICT_SEPARATOR + label.last_common[0]
            else:
                current_level = tree
                
                # Descend to the beginning of the Unary Chain and fill it
                descend_levels = label.n_commons - (len(label.last_common)) + 1
                
                for level_index in range(descend_levels):
                    current_level = current_level.r_child()
                
                for i in range(len(label.last_common)-1):
                    if (current_level.label == C_NONE_LABEL):
                        current_level.label = label.last_common[i]
                    else:
                        current_level.label = current_level.label + C_CONFLICT_SEPARATOR + label.last_common[i]
                    current_level = current_level.r_child()

                # If we reach a POS tag, set it as child of the current chain
                if current_level.is_preterminal():
                    temp_current_level = current_level
                    current_level.label = label.last_common[i+1]
                    current_level.children = [temp_current_level]
                
                else:
                    current_level.label=label.last_common[i+1]
            
            # Fill POS tag in this node or previous one
            if (label.n_commons >= old_n_commons):
                current_level.fill_pos_nodes(postag, word, label.unary_chain, self.unary_joiner)
            
            else:
                old_level.fill_pos_nodes(postag, word, label.unary_chain, self.unary_joiner)

            old_n_commons=label.n_commons
            old_level=current_level

        tree=tree.children[0]
        tree.reverse_tree()
        return tree

An approach to this could be to get the path to leaves in reverse order, efectivelly using the same algorithm as naive absolute / naive relative but now w<sub>i-1</sub> will be w<sub>i+1</sub>

In [5]:
constituent_tree = ConstituentTree.from_string(tree_string)
path_to_leaves = constituent_tree.path_to_leaves(dummy=C_DUMMY_START)
for p in path_to_leaves:
    print(p)


['S0', 'PUNCT0', '.']
['S0', 'VP1', 'SBAR1', 'S1', 'VP+VBP1', 'seem']
['S0', 'VP1', 'SBAR1', 'S1', 'NP+PRP2', 'they']
['S0', 'VP1', 'SBAR1', 'WHNP+WP2', 'what']
['S0', 'VP1', 'RB2', 'not']
['S0', 'VP1', 'VBP3', 'are']
['S0', 'NP2', 'NNS2', 'owls']
['S0', 'NP2', 'DT3', 'The']
['S0', '-START-']


In [11]:
incr_enc = C_NaiveIncremental(separator="_", unary_joiner="+")
constituent_tree = ConstituentTree.from_string(tree_string)

w, p, l, f = incr_enc.encode(constituent_tree)
for i in range(len(w)):
    n, lc = l[i].n_commons, l[i].last_common
    print("Word '{}' has {} common ancestors with his previous word. The last common ancestor is {}".format(w[i], n, lc))

Word 'The' has 1 common ancestors with his previous word. The last common ancestor is S
Word 'owls' has 2 common ancestors with his previous word. The last common ancestor is NP
Word 'are' has 1 common ancestors with his previous word. The last common ancestor is S
Word 'not' has 2 common ancestors with his previous word. The last common ancestor is VP
Word 'what' has 2 common ancestors with his previous word. The last common ancestor is VP
Word 'they' has 3 common ancestors with his previous word. The last common ancestor is SBAR
Word 'seem' has 4 common ancestors with his previous word. The last common ancestor is S
Word '.' has 1 common ancestors with his previous word. The last common ancestor is S


As we will have the whole tree during the decoding process, we could also reverse the order of the linearized tree rows and implement decoding backwards.

In [18]:
from src.utils.constants import C_STRAT_MAX
tree_string = '(S (NP (DT The) (NNS owls)) (VP (VBP are) (RB not) (SBAR (WHNP (WP what)) (S (NP (PRP they)) (VP (VBP seem))))) (PUNCT .))'
tree = Tree.fromstring(tree_string)

print("\n>> Original Tree")
tree.pretty_print()

incr_enc = C_NaiveIncremental(separator="_", unary_joiner="+")
constituent_tree = ConstituentTree.from_string(tree_string)

w, p, l, f = incr_enc.encode(constituent_tree)
linearized_tree = [(wi, pi, li) for wi, pi, li, fi in zip(w, p, l, f)]

print("\n>> Linearized Tree\n")
for line in linearized_tree:
    print(line)
decoded_tree = incr_enc.decode(linearized_tree)
decoded_tree.postprocess_tree(conflict_strat=C_STRAT_MAX, clean_nulls=True)
tree = Tree.fromstring(str(decoded_tree))

print("\n>> Decoded Tree")
tree.pretty_print()


>> Original Tree
                  S                          
      ____________|_______________________    
     |                VP                  |  
     |         _______|____               |   
     |        |   |       SBAR            |  
     |        |   |    ____|____          |   
     |        |   |   |         S         |  
     |        |   |   |     ____|___      |   
     NP       |   |  WHNP  NP       VP    |  
  ___|___     |   |   |    |        |     |   
 DT     NNS  VBP  RB  WP  PRP      VBP  PUNCT
 |       |    |   |   |    |        |     |   
The     owls are not what they     seem   .  


>> Linearized Tree

('The', 'DT', 1_S)
('owls', 'NNS', 2_NP)
('are', 'VBP', 1_S)
('not', 'RB', 2_VP)
('what', 'WP', 2_VP_WHNP)
('they', 'PRP', 3_SBAR_NP)
('seem', 'VBP', 4_S_VP)
('.', 'PUNCT', 1_S)

>> Decoded Tree
                  S                          
      ____________|_______________________    
     |                VP                  |  
     |         _______