In [1]:
from ete3 import Tree
import re
import json
import codecs
from syntax_tree import Syntax_tree
from constituent import Constituent

In [3]:
def to_newick_format(parse_tree):
    parse_tree = parse_tree.replace(",", "*COMMA*")
    parse_tree = parse_tree.replace(":", "*COLON*")
    tree_list = load_syntax_tree(parse_tree)
    if tree_list == None:
        return None
    tree_list = tree_list[1] #去 root
    s = syntax_tree_to_newick(tree_list)
    s = s.replace(",)",")")
    if s[-1] == ",":
        s = s[:-1] + ";"
    return s

def load_syntax_tree(raw_text):
    stack = ["ROOT"]
    text = re.sub(r"\(", " ( ", raw_text)
    text = re.sub(r"\)", " ) ", text)
    text = re.sub(r"\n", " ", text)
    text = re.sub(r"\s+", " ", text)
    text = re.sub(r"^\(\s*\(\s*", "", text)
    text = re.sub(r"\s*\)\s*\)$", "", text)
    for c in text.strip(" ").split(" "):
        if c == ")":
            node = []
            while(1):
                popped = stack.pop()
                if popped == "(":
                    break
                node.append(popped)
            node.reverse()
            if len(node) > 1:
                stack.append(node)
            else:
                if node == []:
                    return None
                stack.append(node[0])
        else:
            stack.append(c)
    return stack

def syntax_tree_to_newick(syntax_tree):
    s = "("
    for child in syntax_tree[1:]:
        if not isinstance(child,list):
            s += child
        else:
            s += syntax_tree_to_newick(child)
    s += ")" + str(syntax_tree[0]) + ","
    return s


def get_all_tree(parse_tree):
    parse_tree_text = to_newick_format(parse_tree)
    tree = Tree(parse_tree_text, format=1)
    treelist = []
    tree_dict = {o:str(i) for i,o in enumerate(tree.get_leaves())}
    return [[int(i) for i in o.split()] for o in set(_get_all_tree(tree, treelist, tree_dict))]

def _get_all_tree(tree, treelist, tree_dict):
    punct = ['.', ',']
    treelist.append(' '.join([tree_dict[o] for o in tree.get_leaves() if str(o).split('-')[-1] not in punct]))
    if tree.get_children() == []:
        return treelist
    else:
        for child in tree.get_children():
            treelist = _get_all_tree(child, treelist, tree_dict)
        return treelist
    

def merge3dicts(x, y, z):
    m = x
    m.update(y)
    m.update(z)
    return m

def get_related_doc(parse_data, docid):
    ret = []
    for i, r in enumerate(parse_data):
        if r['DocID'] == docid:
            ret.append(r)
    return ret

def _get_constituents(parse_dict, DocID, sent_index, conn_index):
    parse_tree = parse_dict[DocID]["sentences"][sent_index]["parsetree"].strip()
    syntax_tree = Syntax_tree(parse_tree)
    if syntax_tree.tree == None:
        return []
    conn_indices = conn_index
    constituent_nodes = []
    if len(conn_indices) == 1:# like and or so...
        conn_node = syntax_tree.get_leaf_node_by_token_index(conn_indices[0]).up
    else:
        conn_node = syntax_tree.get_common_ancestor_by_token_indices(conn_indices)
        conn_leaves = set([syntax_tree.get_leaf_node_by_token_index(conn_index) for conn_index in conn_indices])
        children = conn_node.get_children()
        for child in children:
            leaves = set(child.get_leaves())
            if conn_leaves & leaves == set([]):
                constituent_nodes.append(child)

    curr = conn_node
    while not curr.is_root():
        constituent_nodes.extend(syntax_tree.get_siblings(curr))
        curr = curr.up

    # obtain the Constituent object according to the node.
    constituents = []
    for node in constituent_nodes:
        cons = Constituent(syntax_tree, node)
#         cons.connective = connective
        constituents.append(cons)
    return constituents

In [163]:
# parse_tree = "( (S (S (NP (DT Some)) (VP (MD may) (VP (VB have) (VP (VBN forgotten))))) (: --) (CC and) (S (NP (DT some) (JJR younger) (NNS ones)) (VP (MD may) (ADVP (RB never)) (VP (VB have) (ADJP (JJ experienced)) (: --) (SBAR (WHNP (WP what)) (S (NP (PRP it)) (VP (VBZ 's) (VP (VB like) (S (VP (TO to) (VP (VB invest) (PP (IN during) (NP (DT a) (NN recession))))))))))))) (. .)) )"
# all_tree = get_all_tree(parse_tree)

In [64]:
conll_train = '/home/pengfei/data/2015-2016_conll_shared_task/data/conll16st-en-03-29-16-train/pdtb-parses.json'
parse_dict_train = json.loads(codecs.open(conll_train, encoding='utf-8', errors='ignore').read())
conll_dev = '/home/pengfei/data/2015-2016_conll_shared_task/data/conll16st-en-03-29-16-dev/pdtb-parses.json'
parse_dict_dev = json.loads(codecs.open(conll_dev, encoding='utf-8', errors='ignore').read())
conll_test = '/home/pengfei/data/2015-2016_conll_shared_task/data/conll16st-en-03-29-16-test/pdtb-parses.json'
parse_dict_test = json.loads(codecs.open(conll_test, encoding='utf-8', errors='ignore').read())
parse_dict = merge3dicts(parse_dict_train, parse_dict_dev, parse_dict_test)

parse_data_path = "/home/pengfei/data/2015-2016_conll_shared_task/data/conll16st-en-03-29-16-train/relations.json"
parse_data_dev_path = '/home/pengfei/data/2015-2016_conll_shared_task/data/conll16st-en-03-29-16-dev/relations.json'
parse_data_test_path = '/home/pengfei/data/2015-2016_conll_shared_task/data/conll16st-en-03-29-16-test/relations.json'
parse_data = [json.loads(line) for line in codecs.open(parse_data_path).readlines()]
parse_data_dev = [json.loads(line) for line in codecs.open(parse_data_dev_path).readlines()]
parse_data_test = [json.loads(line) for line in codecs.open(parse_data_test_path).readlines()]
parse_data.extend(parse_data_dev)
parse_data.extend(parse_data_test)

In [14]:
cons = _get_constituents(parse_dict, 'wsj_0279', 11)        

In [65]:
true = 0
false = 0
length = 0
count = 0
for r in parse_data[:10000]:
    if r['Type'] == 'Explicit':
        sent_index = list(set([o[3] for o in r['Arg2']['TokenList']]))
        if len(sent_index) == 1:
            sent_index = sent_index[0]
            conn_indices = [o[4] for o in r['Connective']['TokenList']]
            constituents = _get_constituents(parse_dict, r['DocID'], sent_index, conn_indices)
            constituents = sorted(constituents, key=lambda constituent: constituent.indices[0])   # sort by age
            first_level = [constituents[i].indices for i in range(len(constituents))]
#             second_level = [constituents[i].indices + constituents[i+1].indices for i in range(len(constituents)-1)]
            second_level = [constituents[i].indices + constituents[j].indices for i in range(len(constituents)) for j in range(len(constituents)) if i<j]
            third_level = []
            for i in range(2, len(constituents)):
                for j in range(1, i):
                    for k in range(j):
                        third_level.append(constituents[k].indices + constituents[j].indices + constituents[i].indices)
#             third_level = [constituents[i].indices + constituents[i+1].indices + constituents[i+2].indices for i in range(len(constituents)-2)]
            fourth_level = [constituents[i].indices + constituents[i+1].indices + constituents[i+2].indices + constituents[i+3].indices for i in range(len(constituents)-3)]
            fifth_level = [constituents[i].indices + constituents[i+1].indices + constituents[i+2].indices + constituents[i+3].indices + constituents[i+4].indices for i in range(len(constituents)-4)]
            sixth_level = [constituents[i].indices + constituents[i+1].indices + constituents[i+2].indices + constituents[i+3].indices + constituents[i+4].indices + constituents[i+5].indices for i in range(len(constituents)-5)]
            constituents = first_level + second_level + third_level + fourth_level + fifth_level + sixth_level
            length += len(constituents)
            count += 1
            token_index = [o[4] for o in r['Arg2']['TokenList']]
            if token_index in constituents:
                true += 1
            else:
                false += 1

true, false, true / (true + false), length / count
            

(3874, 623, 0.8614631976873471, 111.72203691349789)

In [157]:
true / (true + false)

0.6274509803921569

In [4]:
conll_train = '/home/pengfei/data/PDTB-3.0/all/conll/train/pdtb-parses.json'
parse_dict_train = json.loads(codecs.open(conll_train, encoding='utf-8', errors='ignore').read())
conll_dev = '/home/pengfei/data/PDTB-3.0/all/conll/dev/pdtb-parses.json'
parse_dict_dev = json.loads(codecs.open(conll_dev, encoding='utf-8', errors='ignore').read())
conll_test = '/home/pengfei/data/PDTB-3.0/all/conll/test/pdtb-parses.json'
parse_dict_test = json.loads(codecs.open(conll_test, encoding='utf-8', errors='ignore').read())
print("datasets loaded")
parse_dict = merge3dicts(parse_dict_train, parse_dict_dev, parse_dict_test)

parse_data_path = "/home/pengfei/data/PDTB-3.0/all/conll/train/relations.json"
parse_data_dev_path = '/home/pengfei/data/PDTB-3.0/all/conll/dev/relations.json'
parse_data_test_path = '/home/pengfei/data/PDTB-3.0/all/conll/test/relations.json'
parse_data = [json.loads(line) for line in codecs.open(parse_data_path).readlines()]
parse_data_dev = [json.loads(line) for line in codecs.open(parse_data_dev_path).readlines()]
parse_data_test = [json.loads(line) for line in codecs.open(parse_data_test_path).readlines()]
parse_data.extend(parse_data_dev)
parse_data.extend(parse_data_test)

datasets loaded


In [8]:
true = 0
false = 0
length = 0
count = 0
for r in parse_data[:1000]:
    if r['Type'] == 'Explicit':
        sent_index = list(set([o[3] for o in r['Arg2']['TokenList']]))
        if len(sent_index) == 1 and r['Connective']['TokenList'][0][3] == sent_index[0]:
            sent_index = sent_index[0]
            conn_indices = [o[4] for o in r['Connective']['TokenList']]
            constituents = _get_constituents(parse_dict, r['DocID'], sent_index, conn_indices)
            constituents = sorted(constituents, key=lambda constituent: constituent.indices[0])   # sort by age
            first_level = [constituents[i].indices for i in range(len(constituents))]
    #             second_level = [constituents[i].indices + constituents[i+1].indices for i in range(len(constituents)-1)]
    #             third_level = [constituents[i].indices + constituents[i+1].indices + constituents[i+2].indices for i in range(len(constituents)-2)]
    #             fourth_level = [constituents[i].indices + constituents[i+1].indices + constituents[i+2].indices + constituents[i+3].indices for i in range(len(constituents)-3)]
    #             fifth_level = [constituents[i].indices + constituents[i+1].indices + constituents[i+2].indices + constituents[i+3].indices + constituents[i+4].indices for i in range(len(constituents)-4)]
    #             sixth_level = [constituents[i].indices + constituents[i+1].indices + constituents[i+2].indices + constituents[i+3].indices + constituents[i+4].indices + constituents[i+5].indices for i in range(len(constituents)-5)]
            second_level = [constituents[i].indices + constituents[j].indices for i in range(len(constituents)) for j in range(len(constituents)) if i<j]
    #             third_level = []
    #             for i in range(2, len(constituents)):
    #                 for j in range(1, i):
    #                     for k in range(j):
    #                         third_level.append(constituents[k].indices + constituents[j].indices + constituents[i].indices)
            constituents = first_level + second_level #+ third_level + fourth_level + fifth_level + sixth_level
            length += len(constituents)
            count += 1
            token_index = [o[4] for o in r['Arg2']['TokenList']]
            if token_index in constituents:
                true += 1
            else:
                false += 1
    #                 print(token_index)
    #                 print()
    #                 dis = [len(set(token_index).symmetric_difference(set(o))) for o in constituents]
    #                 print(constituents[dis.index(min(dis))])
    #                 print("==========")

true, false, true / (true + false), length / count, count

(357, 98, 0.7846153846153846, 36.417582417582416, 455)

In [None]:
true = 0
false = 0
length = 0
count = 0
for r in parse_data[:1000]:
    if r['Type'] == 'Explicit':
        sent_index = list(set([o[3] for o in r['Arg1']['TokenList']]))
        if len(sent_index) == 1 and r['Connective']['TokenList'][0][3] == sent_index[0]:
            sent_index = sent_index[0]
            conn_indices = [o[4] for o in r['Connective']['TokenList']]
            constituents = _get_constituents(parse_dict, r['DocID'], sent_index, conn_indices)
            constituents = sorted(constituents, key=lambda constituent: constituent.indices[0])   # sort by age
            first_level = [constituents[i].indices for i in range(len(constituents))]
    #             second_level = [constituents[i].indices + constituents[i+1].indices for i in range(len(constituents)-1)]
    #             third_level = [constituents[i].indices + constituents[i+1].indices + constituents[i+2].indices for i in range(len(constituents)-2)]
    #             fourth_level = [constituents[i].indices + constituents[i+1].indices + constituents[i+2].indices + constituents[i+3].indices for i in range(len(constituents)-3)]
    #             fifth_level = [constituents[i].indices + constituents[i+1].indices + constituents[i+2].indices + constituents[i+3].indices + constituents[i+4].indices for i in range(len(constituents)-4)]
    #             sixth_level = [constituents[i].indices + constituents[i+1].indices + constituents[i+2].indices + constituents[i+3].indices + constituents[i+4].indices + constituents[i+5].indices for i in range(len(constituents)-5)]
            second_level = [constituents[i].indices + constituents[j].indices for i in range(len(constituents)) for j in range(len(constituents)) if i<j]
    #             third_level = []
    #             for i in range(2, len(constituents)):
    #                 for j in range(1, i):
    #                     for k in range(j):
    #                         third_level.append(constituents[k].indices + constituents[j].indices + constituents[i].indices)
            constituents = first_level + second_level #+ third_level + fourth_level + fifth_level + sixth_level
            length += len(constituents)
            count += 1
            token_index = [o[4] for o in r['Arg2']['TokenList']]
            if token_index in constituents:
                true += 1
            else:
                false += 1
    #                 print(token_index)
    #                 print()
    #                 dis = [len(set(token_index).symmetric_difference(set(o))) for o in constituents]
    #                 print(constituents[dis.index(min(dis))])
    #                 print("==========")

true, false, true / (true + false), length / count, count

In [1]:
import sys
sys.path.append('..')
from pdtb_api.api import PDTB3

In [2]:
pdtb3 = PDTB3()

In [4]:
for i in range(100):
    print(pdtb3.get_highlighted_relation(i, verbose=True))

DocID: wsj_0793,	Type: EntRel,	Sense: EntRel
[7mGround[0m [7mzero[0m [7mof[0m [7mthe[0m [7mHUD[0m [7mscandal[0m [7mis[0m [7mthe[0m [7mSecretary[0m [7m's[0m [7m``[0m [7mdiscretionary[0m [7mfund[0m [7m,[0m [7m''[0m [7ma[0m [7mhoney[0m [7mpot[0m [7mused[0m [7mto[0m [7mfund[0m [7mprojects[0m [7mthat[0m [7mwere[0m [7mn't[0m [7mapproved[0m [7mthrough[0m [7mnormal[0m [7mHUD[0m [7mchannels[0m . [1mJack[0m [1mKemp[0m [1mwants[0m [1mto[0m [1mabolish[0m [1mit[0m . 
DocID: wsj_0793,	Type: Explicit,	Sense: Expansion.Substitution.Arg2-as-subst
[7mJack[0m [7mKemp[0m [7mwants[0m [7mto[0m [7mabolish[0m [7mit[0m [4mInstead[0m [1mCongress[0m [1m's[0m [1midea[0m [1mof[0m [1mreform[0m [1mis[0m [1mto[0m [1mincrease[0m [1mthis[0m [1mslush[0m [1mfund[0m [1mby[0m [1m$[0m [1m28.4[0m [1mmillion[0m 
DocID: wsj_0793,	Type: Implicit,	Sense: Temporal.Asynchronous.Precedence
[7mInstead[0m [7m,[0m [

In [7]:
print('\x1b[7mFollowing\x1b[0m')

[7mFollowing[0m


In [4]:
tree = pdtb3.get_syntax_tree('wsj_1000', 3)

In [8]:
print(str(tree.tree))


      /- /-Over
     |
   /-|   /- /-the
  |  |  |
  |  |  |-- /-past
  |   \-|
  |     |-- /-nine
  |     |
  |      \- /-months
  |
  |-- /-*COMMA*
  |
  |      /- /-several
  |   /-|
  |  |   \- /-firms
  |  |
  |  |-- /-*COMMA*
  |  |
  |  |   /- /-including
  |  |  |
  |  |  |      /- /-discount
  |  |  |     |
  |  |  |     |-- /-broker
  |  |  |     |
  |  |  |     |-- /-Charles
  |  |  |   /-|
  |  |--|  |  |-- /-Schwab
  |  |  |  |  |
  |  |  |  |  |-- /-&
  |  |  |  |  |
  |  |  |  |   \- /-Co.
  |  |  |  |
  |  |  |  |-- /-and
  |  |  |  |
  |--|  |  |      /- /-Sears
  |  |   \-|     |
  |  |     |     |-- /-*COMMA*
  |  |     |     |
  |  |     |     |-- /-Roebuck
--|  |     |   /-|
  |  |     |  |  |-- /-&
  |  |     |  |  |
  |  |     |  |  |-- /-Co.
  |  |     |  |  |
  |  |     |  |   \- /-'s
  |  |     |  |
  |  |      \-|-- /-Dean
  |  |        |
  |  |        |-- /-Witter
  |  |        |
  |  |        |-- /-Reynolds
  |  |        |
  |  |        |-- /-Inc.
  |  |  

In [50]:
class color:
    PURPLE = '\033[95m'
    CYAN = '\033[96m'
    DARKCYAN = '\033[36m'
    BLUE = '\033[94m'
    GREEN = '\033[92m'
    YELLOW = '\033[93m'
    RED = '\033[91m'
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'
    BACKGROUND = '\033[7M'
    END = '\033[0m'

print(color.BOLD + 'Hello World !' + color.END)
print('Hello World !')

[1mHello World ![0m
Hello World !


In [14]:
print ("\x1B[3mHello World\x1B[23m")

[3mHello World[23m


In [16]:
print(''hello word'')

SyntaxError: invalid syntax (<ipython-input-16-50453076d405>, line 1)

In [51]:
for i in range(200):
    print(i, '\033[' + str(i) +'m' + 'Hello World !' + color.END)

0 [0mHello World ![0m
1 [1mHello World ![0m
2 [2mHello World ![0m
3 [3mHello World ![0m
4 [4mHello World ![0m
5 [5mHello World ![0m
6 [6mHello World ![0m
7 [7mHello World ![0m
8 [8mHello World ![0m
9 [9mHello World ![0m
10 [10mHello World ![0m
11 [11mHello World ![0m
12 [12mHello World ![0m
13 [13mHello World ![0m
14 [14mHello World ![0m
15 [15mHello World ![0m
16 [16mHello World ![0m
17 [17mHello World ![0m
18 [18mHello World ![0m
19 [19mHello World ![0m
20 [20mHello World ![0m
21 [21mHello World ![0m
22 [22mHello World ![0m
23 [23mHello World ![0m
24 [24mHello World ![0m
25 [25mHello World ![0m
26 [26mHello World ![0m
27 [27mHello World ![0m
28 [28mHello World ![0m
29 [29mHello World ![0m
30 [30mHello World ![0m
31 [31mHello World ![0m
32 [32mHello World ![0m
33 [33mHello World ![0m
34 [34mHello World ![0m
35 [35mHello World ![0m
36 [36mHello World ![0m
37 [37mHello World ![0m
38 [38mHello World ![0m
39 [3