In [1]:
from nltk.tree import Tree
import codecs
import re

In [2]:
# Load the bracketed input file. Lines whose first character is '#' are
# collected in `posit`; every other non-blank line is collected in `sents`.
sents = []
posit = []
# The context manager closes the file even if reading or decoding fails.
with codecs.open("ctb.bracketed", "r", "utf-8") as f:
    for line in f:
        stripped = line.strip()
        if not stripped:
            # Blank lines hold neither a '#' line nor a bracketed sentence,
            # so skip them instead of storing an empty string.
            continue
        # The '#' test is on the raw line, so only a '#' in column 0 counts.
        if line[0] == '#':
            posit.append(stripped)
        else:
            sents.append(stripped)

In [3]:
# Parse every bracketed sentence string into an nltk Tree, keeping input order.
trees = [Tree.fromstring(bracketed) for bracketed in sents]

In [4]:
import re


class MyTree:
    """Constituency-tree node annotated with head-word information.

    A node is built recursively from an object that is either a ``str`` (a
    leaf holding a word) or something exposing ``label()`` and iteration over
    its children (e.g. an ``nltk.tree.Tree``).

    Attributes set by ``__init__``:
        parent      -- enclosing MyTree node, or None for the root
        layer       -- depth of the node (root is ``current_layer``)
        label       -- the word for a leaf, otherwise ``tree.label()``
        terminal    -- True for leaves
        children    -- list of child MyTree nodes (None for leaves)
        n_children  -- number of children (0 for leaves)

    Attributes set later:
        id                -- set by ``labelDependencyTree``: the 1-based word
                             index for a leaf, the id of the head child otherwise
        dependencyParent  -- set on leaves by ``findDependencyParent``
    """

    # Ordered head rules: (regex for the parent label, regex or tuple of
    # regexes for candidate head children in priority order, scan children
    # right-to-left?).  Parent patterns are applied with ``re.match`` (prefix
    # match) and the FIRST matching rule wins, so the compound-verb rule must
    # come before "VP": otherwise "VP" also matches the label "VPT" and the
    # dedicated rule for it is never reached.
    _HEAD_RULES = (
        ("VCD|VCP|VNV|VPT|VRD|VSB", "V", True),
        ("VP", ("V|BA|LB", "ADVP", "QP", "IP", "NP"), False),
        ("TYPO", "NOI", False),
        ("N", "NP|NN|NR|NT|QP|CLP|PN|UCP", True),
        ("LCP", "LCP|LC", False),
        ("DNP", "DEG|DEC", False),
        ("DVP", "DEV|DEG", False),
        ("DP", "DP|CLP|QP", False),
        ("CLP", "M|CLP", True),
        ("ADJP", "ADJP|JJ", True),
        ("ADVP", "ADVP|AD", True),
        ("FRAG", "NN|VV", True),
        ("INTJ", "IJ", True),
        ("PP", "P$|PP", False),
        ("PRN", "N|VP|IP|ADJP|QP|UCP", True),
        ("QP", ("QP|CLP", "CD"), True),
        ("UCP", "UCP|NP|IP|PP", True),
        ("IP", ("IP|V", "NP|PP"), True),
        ("CP", "CP|IP|V", True),
    )

    def __init__(self, tree, parent_node=None, current_layer=0):
        """Build the node for ``tree`` and, recursively, for all its children."""
        self.parent = parent_node
        self.layer = current_layer
        if isinstance(tree, str):
            # A bare string is a leaf: its label is the word itself.
            self.label = tree
            self.terminal = True
            self.children = None
            self.n_children = 0
            return
        self.label = tree.label()
        self.terminal = False
        self.children = [MyTree(son, self, current_layer + 1) for son in tree]
        self.n_children = len(self.children)

    def __getitem__(self, i):
        """Return the i-th child node."""
        return self.children[i]

    def labelDependencyTree(self, current_n_words=1):
        """Assign ``id`` to every node of the subtree.

        Leaves are numbered left to right starting at ``current_n_words``;
        every inner node receives the id of its head child.

        Returns:
            The next unused word index (last assigned leaf id + 1).

        Raises:
            AttributeError: if no head child can be determined for an inner
                node.  The type matches the error the unguarded ``None.id``
                access used to produce, but the message names the node.
        """
        if self.terminal:
            self.id = current_n_words
            return current_n_words + 1
        for son in self.children:
            current_n_words = son.labelDependencyTree(current_n_words)
        head = self.findHead()
        if head is None:
            raise AttributeError(
                "no head child found for node %r with children %r"
                % (self.label, [child.label for child in self.children]))
        self.id = head.id
        return current_n_words

    def tryFetchHead(self, try_labels, reverse=False):
        """Return the first child whose label matches one of ``try_labels``.

        Args:
            try_labels: a regex string or a sequence of regex strings, tried
                in order; earlier patterns take priority over child position.
            reverse: scan the children right-to-left instead of left-to-right.

        Returns:
            The matching child node, or None when nothing matches.
        """
        if isinstance(try_labels, str):
            try_labels = (try_labels,)
        # Leaves have children == None; treat that as "no candidates".
        candidates = list(self.children or ())
        if reverse:
            candidates.reverse()
        for lab in try_labels:
            for child in candidates:
                if re.match(lab, child.label):
                    return child
        return None

    def findHead(self):
        """Return the head child of this inner node, or None if none is found.

        Selection order: an only child is always the head; otherwise the
        rightmost child whose label matches "CC"; otherwise the child picked
        by the first entry of ``_HEAD_RULES`` whose parent pattern matches
        this node's label.

        Raises:
            AttributeError: when called on a terminal node.
        """
        if self.terminal:
            raise AttributeError("cannot call self.findHead() on terminal nodes.")
        if len(self.children) == 1:
            return self.children[0]

        # coordination
        head = self.tryFetchHead("CC", reverse=True)
        if head:
            return head

        for parent_pattern, child_patterns, from_right in self._HEAD_RULES:
            if re.match(parent_pattern, self.label):
                # The first matching rule decides, even if it finds no child.
                return self.tryFetchHead(child_patterns, reverse=from_right)
        return None

    def findDependencyParent(self):
        """Set ``dependencyParent`` on every leaf of the subtree.

        For a leaf, walk up the ancestors until one has a different ``id``;
        that id is the dependency parent.  If the walk runs past the root
        the leaf gets ``dependencyParent = 0``.  Requires
        ``labelDependencyTree`` to have been called first.
        """
        if self.terminal:
            t = self.parent
            # `t is not None` also covers a leaf that has no parent at all.
            while t is not None and t.id == self.id:
                t = t.parent
            self.dependencyParent = 0 if t is None else t.id
            return
        for son in self.children:
            son.findDependencyParent()

    def printDependencyConLL(self):
        """Return one 10-column, tab-separated line per leaf of the subtree.

        Columns: id, word, "_", "_", label of the node directly above the
        word, "_", dependencyParent, "_", "_", "_".  A leaf without a parent
        gets "_" in the fifth column.
        """
        if self.terminal:
            above = self.parent.label if self.parent is not None else "_"
            return "%d\t%s\t_\t_\t%s\t_\t%d\t_\t_\t_\n" % (
                self.id, self.label, above, self.dependencyParent)
        return "".join(son.printDependencyConLL() for son in self.children)


In [5]:
# Convert each parsed tree: number the words, resolve every word's dependency
# parent, and render the result preceded by the matching `posit` line.
mytrees = []
rendered_blocks = []
for i, tree in enumerate(trees):
    _mytree = MyTree(tree)
    _mytree.labelDependencyTree()
    _mytree.findDependencyParent()
    rendered_blocks.append(posit[i] + '\n' + _mytree.printDependencyConLL() + '\n')
    mytrees.append(_mytree)
# Full output text: one block per sentence, each followed by a blank line.
s = "".join(rendered_blocks)

In [6]:
# Write the accumulated output text as UTF-8. The context manager closes
# (and flushes) the file even if the write itself raises.
with codecs.open("output.conll", "w", "utf-8") as f:
    f.write(s)

In [7]:
# useful function to check for exceptions
def traverse(tree, pat, cat, key=lambda x: x > 0):
    """Find (depth-first, pre-order) the first subtree matching a child-count test.

    A node matches when its base label (the part of ``label()`` before the
    first '-') matches the regex ``cat``, it has more than one child, and
    ``key(p)`` is true, where ``p`` is the number of its direct children whose
    label matches the regex ``pat``.  The matching subtree is printed and
    returned.

    Args:
        tree: an object with ``label()``, ``len()`` and iteration over
            children (e.g. an nltk Tree); anything without ``label()`` is
            treated as a leaf.
        pat: regex applied with ``re.match`` to each direct child's label.
        cat: regex applied with ``re.match`` to the node's base label.
        key: predicate on the match count ``p``; defaults to "at least one".

    Returns:
        The first matching subtree, or None if there is none.
    """
    try:
        base_label = tree.label().split('-')[0]
    except AttributeError:
        # Leaves have no label(): nothing to inspect here.
        return None

    p = 0
    for child in tree:
        try:
            if re.match(pat, child.label()):
                p += 1
        except AttributeError:
            # A child without label() is a leaf; such a node is not searched.
            return None

    if re.match(cat, base_label) and len(tree) > 1 and key(p):
        print(tree)
        print("\n")
        return tree

    for child in tree:
        # Forward the search criteria: the recursive call needs pat, cat and key.
        found = traverse(child, pat, cat, key)
        if found is not None:
            return found
    return None