In [11]:
import os
import re
from collections import namedtuple
from enum import Enum

In [24]:
# A typed span: `type` is the segment role taken from the relation head
# ("N" or "S"), `start`/`end` are offsets into the collected text buffer.
Segment = namedtuple("Segment", "type start end")


class Relation:
    """A binary relation node of the parsed tree.

    Attributes:
        type: the relation label read from the head (e.g. "Elaboration").
        left: segment covered by the left operand.
        right: segment covered by the right operand.
        left_child: nested ``Relation`` for the left operand, or ``None``
            when the left operand is a plain leaf segment.
        right_child: nested ``Relation`` for the right operand, or ``None``
            when the right operand is a plain leaf segment.
    """

    def __init__(
        self,
        relation_type,
        left_segment,
        right_segment,
        left_relation,
        right_relation
    ):
        self.type = relation_type
        self.left = left_segment
        self.right = right_segment
        self.left_child = left_relation
        self.right_child = right_relation

    def __repr__(self):
        """Return a readable, namedtuple-like representation of the node.

        Without this, displaying a parse result shows only opaque object
        addresses, which makes the relation dictionary unreadable.
        """
        return (
            "{}(type={!r}, left={!r}, right={!r}, "
            "left_child={!r}, right_child={!r})"
        ).format(
            self.__class__.__name__,
            self.type,
            self.left,
            self.right,
            self.left_child,
            self.right_child,
        )

    def get_first_nucleus(self):
        """Return ``(child_relation, segment)`` for the first "N" operand.

        The left operand is returned when its segment type is "N";
        otherwise the right operand is returned without checking its type.
        """
        if self.left.type == "N":
            return self.left_child, self.left
        else:
            return self.right_child, self.right

    def get_satellite(self):
        """Return ``(child_relation, segment)`` for the "S" operand.

        The left operand is checked first, then the right one. When neither
        segment has type "S", ``(None, None)`` is returned.
        """
        if self.left.type == "S":
            return self.left_child, self.left
        elif self.right.type == "S":
            return self.right_child, self.right
        else:
            return None, None

In [13]:
def skip_whitespace(text, pointer):
    """Measure the run of whitespace in ``text`` that begins at ``pointer``.

    Returns the number of whitespace characters found (possibly 0), not the
    new position; callers add the result to their own pointer.
    """
    cursor = pointer
    limit = len(text)
    while cursor < limit and text[cursor].isspace():
        cursor += 1
    return cursor - pointer

In [14]:
def check_symbol(text, pointer, c):
    """Tell whether ``text`` holds the character ``c`` at index ``pointer``.

    A pointer at or past the end of ``text`` yields ``False`` instead of
    raising.
    """
    return pointer < len(text) and text[pointer] == c

In [15]:
def is_segment_start(tree_text, pointer):
    """Tell whether the opening marker ``_!`` sits at ``pointer``."""
    if not check_symbol(tree_text, pointer, '_'):
        return False
    return check_symbol(tree_text, pointer + 1, '!')


def is_segment_end(tree_text, pointer):
    """Tell whether the closing marker ``!_`` sits at ``pointer``."""
    if not check_symbol(tree_text, pointer, '!'):
        return False
    return check_symbol(tree_text, pointer + 1, '_')

In [16]:
# Relation head such as "Elaboration[N][S]": group 1 is the relation label
# (everything before the first "["), groups 2 and 3 are the left and right
# segment types. A raw string is used so that "\[" / "\]" reach the regex
# engine unchanged instead of being treated as (invalid) string escapes.
head_re = re.compile(r"([^\[]*)\[(N|S)\]\[(N|S)\]")

def read_head(tree_text, pointer):
    """Parse a relation head of the form ``(Label[N|S][N|S]``.

    Leading whitespace is skipped, an opening parenthesis is required
    (checked with ``assert``), and the head token extends up to the next
    whitespace character or the end of the input.

    Returns:
        ``(relation_type, left_segment_type, right_segment_type, pointer)``
        where ``pointer`` is the index just past the head token.
    """
    pointer += skip_whitespace(tree_text, pointer)
    assert check_symbol(tree_text, pointer, '(')
    head_start = pointer + 1

    # Advance to the first whitespace character after the head token.
    head_end = head_start
    text_length = len(tree_text)
    while head_end < text_length and not tree_text[head_end].isspace():
        head_end += 1

    head_match = head_re.match(tree_text[head_start:head_end])
    relation_type, left_segment_type, right_segment_type = head_match.groups()

    return relation_type, left_segment_type, right_segment_type, head_end

In [17]:
def relation_to_segment(relation, segment_type):
    """Collapse ``relation`` into a single segment of type ``segment_type``.

    The resulting span runs from the start of the relation's left segment
    to the end of its right segment.
    """
    span_start = relation.left.start
    span_end = relation.right.end
    return Segment(type=segment_type, start=span_start, end=span_end)

In [18]:
def read_segment(tree_text, pointer, text, segment_type):
    """Parse one leaf segment written as ``_!...!_``.

    The characters between the markers are appended one by one to ``text``
    (a list used as the shared output buffer), followed by a single space.
    Both markers are required and are checked with ``assert``.

    Returns:
        ``(segment, pointer)`` where ``segment`` spans the appended
        characters (including the trailing space) as offsets into ``text``,
        and ``pointer`` is the index just past the closing marker.
    """
    pointer += skip_whitespace(tree_text, pointer)
    assert is_segment_start(tree_text, pointer)
    pointer += 2  # step over the opening "_!"

    segment_start = len(text)
    tree_length = len(tree_text)
    while pointer < tree_length and not is_segment_end(tree_text, pointer):
        text.append(tree_text[pointer])
        pointer += 1

    assert is_segment_end(tree_text, pointer)
    text.append(" ")  # separator between consecutive segments
    segment = Segment(segment_type, segment_start, len(text))

    # Step over the closing "!_" before handing the position back.
    return segment, pointer + 2

In [19]:
def read_relation_or_segment(
    tree_text,
    pointer,
    text,
    segment_type,
    relations
):
    """Parse one operand, which is either a nested relation or a leaf.

    After skipping whitespace, an opening parenthesis selects the nested
    relation branch; anything else is parsed as a leaf segment.

    Returns:
        ``(segment, pointer, relation)``. For a nested relation, ``segment``
        is the span covering the whole relation and ``relation`` is the
        parsed node; for a leaf, ``relation`` is ``None``.
    """
    pointer += skip_whitespace(tree_text, pointer)

    if not check_symbol(tree_text, pointer, '('):
        leaf, pointer = read_segment(tree_text, pointer, text, segment_type)
        return leaf, pointer, None

    nested, pointer = read_relation(tree_text, pointer, text, relations)
    return relation_to_segment(nested, segment_type), pointer, nested

In [20]:
def read_relation(tree_text, pointer, text, relations):
    """Parse a full ``(Head left right)`` relation starting at ``pointer``.

    Segment characters are accumulated in ``text`` and every parsed relation
    is appended to ``relations[relation_type]``. Nested relations are parsed
    (and therefore registered) before the relation that contains them. The
    closing parenthesis is required and checked with ``assert``.

    Returns:
        ``(relation, pointer)`` where ``pointer`` is the index just past the
        closing parenthesis.
    """
    head = read_head(tree_text, pointer)
    relation_type, left_segment_type, right_segment_type, pointer = head

    left_segment, pointer, left_child = read_relation_or_segment(
        tree_text, pointer, text, left_segment_type, relations
    )
    right_segment, pointer, right_child = read_relation_or_segment(
        tree_text, pointer, text, right_segment_type, relations
    )

    pointer += skip_whitespace(tree_text, pointer)
    assert check_symbol(tree_text, pointer, ')')
    pointer += 1

    relation = Relation(
        relation_type, left_segment, right_segment, left_child, right_child
    )
    # Group relations by label, creating the bucket on first use.
    relations.setdefault(relation_type, []).append(relation)

    return relation, pointer

In [21]:
def read_relations(tree_text):
    """Parse ``tree_text`` and return the flat text plus relations by label.

    Returns:
        ``(text, relations)`` where ``text`` is the concatenation of all
        segment characters collected during parsing and ``relations`` maps
        each relation label to the list of relations carrying it.
    """
    collected_chars = []
    relations_by_type = {}
    read_relation(tree_text, 0, collected_chars, relations_by_type)
    return "".join(collected_chars), relations_by_type


def extract_relations(file_path):
    """Read the tree stored at ``file_path`` and parse it with
    ``read_relations``, returning its ``(text, relations)`` result."""
    with open(file_path, "rt") as tree_file:
        contents = tree_file.read()
    return read_relations(contents)

In [29]:
def read_relation_tree(tree_text):
    """Parse ``tree_text`` and return only the root relation.

    The text buffer and the per-label relation index built during parsing
    are throwaway containers here and are discarded.
    """
    root_relation, _end_pointer = read_relation(tree_text, 0, [], {})
    return root_relation


def extract_relation_tree(file_path):
    """Read the tree stored at ``file_path`` and return its root relation
    as produced by ``read_relation_tree``."""
    with open(file_path, "rt") as tree_file:
        contents = tree_file.read()
    return read_relation_tree(contents)

In [23]:
# tree_text = """
#     (Elaboration[N][S]
#        _!ha-ha !_
#        (Elaboration[N][S] _!this is a  segment !_ 
#        (Join[N][N] _!a .!_ _!b .!_)))
# """
# read_relations(tree_text)

('ha-ha  this is a  segment  a . b . ',
 {'Elaboration': [Relation(type='Elaboration', left=Segment(type='N', start=7, end=27), right=Segment(type='S', start=27, end=35), left_child=None, right_child=Relation(type='Join', left=Segment(type='N', start=27, end=31), right=Segment(type='N', start=31, end=35), left_child=None, right_child=None)),
   Relation(type='Elaboration', left=Segment(type='N', start=0, end=7), right=Segment(type='S', start=7, end=35), left_child=None, right_child=Relation(type='Elaboration', left=Segment(type='N', start=7, end=27), right=Segment(type='S', start=27, end=35), left_child=None, right_child=Relation(type='Join', left=Segment(type='N', start=27, end=31), right=Segment(type='N', start=31, end=35), left_child=None, right_child=None)))],
  'Join': [Relation(type='Join', left=Segment(type='N', start=27, end=31), right=Segment(type='N', start=31, end=35), left_child=None, right_child=None)]})

In [239]:
# text = []
# relations = {}
# read_relation_or_segment(
#     tree_text, 
#     0, 
#     text, 
#     "S",  # segment types are the plain strings "N" / "S"
#     relations
# )

In [240]:
# text, relations = extract_relations(
#     "parsed/race/train/high/10324.txt.tree"
# )