In [42]:
import enum
import collections
import re

In [43]:
from nltk.tokenize import word_tokenize
from nltk import pos_tag

In [None]:
def print_if_verbose(text, verbose):
    """Print ``text`` to stdout only when ``verbose`` is truthy."""
    if not verbose:
        return
    print(text)

In [44]:
# Whitespace that directly precedes one of . , ? ! ' (captured as group 1).
# Raw string: "\s" in a normal string literal is an invalid escape sequence
# and triggers a warning on modern Python versions.
clean_re = re.compile(r"\s+([.,?!'])")


def clean(string):
    """Remove markup markers and whitespace before punctuation.

    The literal markers ``<s>`` and ``<P>`` are deleted, any whitespace run
    immediately followed by one of ``. , ? ! '`` is dropped (so ``"word ,"``
    becomes ``"word,"``), and the result is stripped on both sides.

    Args:
        string: Text to normalise.

    Returns:
        The cleaned text.
    """
    cleaned = string.replace("<s>", "").replace("<P>", "")
    cleaned = clean_re.sub(r"\1", cleaned)
    return cleaned.strip()

In [45]:
def tokenize(s):
    """Split ``s`` with NLTK's ``word_tokenize`` and lowercase every token."""
    tokens = word_tokenize(s)
    return list(map(str.lower, tokens))


def get_first_token(s):
    """Return the lowercased first token of ``s``, or ``None`` if it has none."""
    tokens = word_tokenize(s)
    return tokens[0].lower() if tokens else None

In [46]:
def find_nested_relation(
    relation_type, relation, cur_depth=1, max_depth=3
):
    """Search a relation tree depth-first for a relation of a given type.

    The left subtree is explored before the right one, so the first match in
    that order is returned.

    Args:
        relation_type: Value compared against each relation's ``type``.
        relation: Root of the (sub)tree to search; may be ``None``. Nodes
            must expose ``type``, ``left_child`` and ``right_child``.
        cur_depth: Depth assigned to ``relation`` (the root is depth 1).
        max_depth: Deepest level that is still inspected.

    Returns:
        ``(relation, depth)`` for the first match, or ``(None, None)`` when
        nothing matches within ``max_depth``.
    """
    if relation is None or cur_depth > max_depth:
        return None, None
    if relation_type == relation.type:
        return relation, cur_depth

    for child in (relation.left_child, relation.right_child):
        # max_depth must be forwarded explicitly; otherwise every nested call
        # silently falls back to the default limit.
        found, depth = find_nested_relation(
            relation_type, child, cur_depth + 1, max_depth
        )
        if found is not None:
            return found, depth
    return None, None

In [47]:
# Testing

# %run ./relation_extraction.ipynb

# relation = Relation(
#     "Explanation", 
#     1, 
#     None,
#     Relation(
#         "Background",
#         2,
#         None,
#         None,
#         None
#     ),
#     None
# )

# print(find_nested_relation("Background", None))

# print(find_nested_relation("Explanation", relation))

# print(find_nested_relation("Background", relation))

# relation = Relation(
#     "Explanation", 
#     1, 
#     None,
#     Relation(
#         "Background",
#         2,
#         None,
#         None,
#         None
#     ),
#     Relation(
#         "Background",
#         3,
#         None,
#         Relation(
#             "Elaboration",
#             4,
#             None,
#             None,
#             None
#         ),
#         None
#     )
# )

# print(find_nested_relation("Background", relation))

# print(find_nested_relation("Elaboration", relation))

In [48]:
# Connectives are read one per line, with surrounding whitespace removed.
with open("aux/connectives.txt", "rt") as f:
    connectives = {line.strip() for line in f}

In [49]:
# Characters treated as punctuation by the leading/trailing strip helpers.
punctuation = {".", ",", "!", "?", ";"}

In [50]:
def remove_trailing_punctuation(string):
    """Drop the final character of ``string`` if it is in ``punctuation``.

    At most one character is removed; the rest of the string is untouched.
    """
    if string and string[-1] in punctuation:
        return string[:-1]
    return string

In [51]:
def remove_leading_punctuation(string):
    """Drop the first character of ``string`` if it is in ``punctuation``.

    When a character is removed, the remainder is also stripped of
    surrounding whitespace; otherwise ``string`` is returned unchanged.
    """
    if string and string[0] in punctuation:
        remainder = string[1:]
        return remainder.strip()
    return string

In [65]:
def lowercase_first_letter(string):
    """Lowercase the first character unless the first token is the pronoun "I"."""
    leading_tokens = word_tokenize(string)[:1]
    if leading_tokens == ["I"]:
        # Keep the pronoun capitalised.
        return string
    head, tail = string[:1], string[1:]
    return head.lower() + tail


def uppercase_first_letter(string):
    """Return ``string`` with its first character uppercased."""
    head, tail = string[:1], string[1:]
    return head.upper() + tail

In [53]:
def trim_connective(string):
    """Strip a leading connective (from ``connectives``) off ``string``.

    Matching is case-insensitive. A connective only counts when it ends on a
    word boundary, i.e. it is followed by the end of the string or by a
    non-alphanumeric character, so a short connective never eats the
    beginning of a longer word. When several connectives match, the longest
    one is used; this makes the result independent of set iteration order.

    The connective "last" is only removed when directly followed by a comma;
    otherwise ``string`` is returned unchanged.

    Args:
        string: Text that may start with a connective.

    Returns:
        The text after the connective, passed through
        ``remove_leading_punctuation``, or ``string`` itself when nothing is
        trimmed.
    """
    lower = string.lower()
    candidates = [
        c
        for c in connectives
        # Skip empty entries (e.g. blank lines in the source file): an empty
        # prefix would match every string.
        if c
        and lower.startswith(c)
        and not lower[len(c):len(c) + 1].isalnum()
    ]
    if not candidates:
        return string

    # Distinct prefixes of one string always differ in length, so the
    # longest match is unique.
    best = max(candidates, key=len)
    if best == "last" and lower[len(best):len(best) + 1] != ",":
        return string
    return remove_leading_punctuation(string[len(best):])

In [54]:
def search_segment(text, relation, direction, verbose=False):
    """Follow one side of a relation tree down to its innermost segment.

    Starting at ``relation``, repeatedly takes the left side when
    ``direction == "left"`` and the right side otherwise, until the chosen
    side has no child relation.

    Args:
        text: Full text that the segments' ``start``/``end`` offsets index.
        relation: Relation to start from; must expose ``left``/``right``
            segments and ``left_child``/``right_child``.
        direction: ``"left"`` to descend on the left; any other value
            descends on the right.
        verbose: When truthy, the text of every visited segment is printed,
            including those of nested relations.

    Returns:
        The slice of ``text`` covered by the innermost segment reached.
    """
    while True:
        if direction == "left":
            child, segment = relation.left_child, relation.left
        else:
            child, segment = relation.right_child, relation.right
        segment_text = text[segment.start:segment.end]
        print_if_verbose(segment_text, verbose)
        if not child:
            return segment_text
        # Descend one level, keeping the same direction and verbosity.
        relation = child
    
    
def get_depth(relation):
    """Return the height of the relation tree rooted at ``relation``.

    A falsy ``relation`` (e.g. ``None``) has depth 0; a relation without
    children has depth 1.
    """
    if not relation:
        return 0
    left_depth = get_depth(relation.left_child)
    right_depth = get_depth(relation.right_child)
    return max(left_depth, right_depth) + 1

In [63]:
# One side (nucleus or satellite) of a relation: the direction string, the
# nested relation on that side (may be None) and the segment it covers.
RelationPartInfo = collections.namedtuple(
    "RelationPartInfo", "direction relation segment"
)


@enum.unique
class NucleusProximity(enum.Enum):
    """Proximity label computed by ``extract_nuclei`` for a pair of nuclei.

    ``extract_nuclei`` picks the value from the side on which the relation's
    nucleus sits and the side on which the satellite's own nucleus sits.
    """
    NEAR = "near"
    FAR = "far"
    
    
class RelationInfo:
    """Holds the nucleus part and the satellite part of a single relation."""

    def __init__(self, nucleus_info, satellite_info):
        # Both parts are stored as given, without copying or validation.
        self.nucleus_info, self.satellite_info = nucleus_info, satellite_info


def get_info(relation, verbose=False):
    """Split ``relation`` into its nucleus part and its satellite part.

    The left side is the nucleus when ``relation.left.type == "N"``;
    otherwise the right side is. Each part is a ``RelationPartInfo`` whose
    ``direction`` names the opposite side: the left part is tagged
    ``"right"`` and the right part ``"left"``.

    Returns:
        A ``RelationInfo`` with ``nucleus_info`` and ``satellite_info``.
    """
    nucleus_on_left = relation.left.type == "N"
    if nucleus_on_left:
        print_if_verbose("Nucleus is on the left.", verbose)
    else:
        print_if_verbose("Nucleus is on the right.", verbose)

    left_part = RelationPartInfo("right", relation.left_child, relation.left)
    right_part = RelationPartInfo("left", relation.right_child, relation.right)

    if nucleus_on_left:
        return RelationInfo(left_part, right_part)
    return RelationInfo(right_part, left_part)


def extract_nuclei(
    text,
    relation_info,
    satellite_nucleus_relation,
    satellite_nucleus_segment,
    verbose=False
):
    """Extract the nucleus text, the satellite's nucleus text and their proximity.

    Args:
        text: Full text indexed by the segments' ``start``/``end`` offsets.
        relation_info: ``RelationInfo`` for the relation being processed.
            ``relation_info.satellite_info.relation`` must not be ``None``.
        satellite_nucleus_relation: Nested relation of the satellite's
            nucleus, or a falsy value if there is none.
        satellite_nucleus_segment: Segment of the satellite's nucleus.
        verbose: When truthy, progress messages are printed, including the
            segments visited while descending on either side.

    Returns:
        ``(nucleus_text, satellite_nucleus_text, nucleus_proximity)`` where
        both texts have been passed through ``clean`` and the proximity is a
        ``NucleusProximity`` member.
    """
    nucleus_info = relation_info.nucleus_info
    satellite_info = relation_info.satellite_info

    # If the nucleus has nested relations and they are too deep, take the
    # closest nested segment relative to the satellite; otherwise take the
    # nucleus as is.
    if nucleus_info.relation and get_depth(nucleus_info.relation) > 3:
        nucleus_text = clean(
            search_segment(
                text,
                nucleus_info.relation,
                nucleus_info.direction,
                verbose
            )
        )
    else:
        nucleus_text = clean(
            text[nucleus_info.segment.start:nucleus_info.segment.end]
        )

    # Determine on which side the satellite's own nucleus sits.
    satellite_nucleus_on_left = satellite_info.relation.left.type == "N"
    if satellite_nucleus_on_left:
        print_if_verbose("Satellite's nucleus is on the left.", verbose)
    else:
        print_if_verbose("Satellite's nucleus is on the right.", verbose)

    # A nucleus direction of "right" means the nucleus itself is on the left.
    nucleus_on_left = nucleus_info.direction == "right"
    # NEAR when both nuclei sit on the same side of their relations.
    if satellite_nucleus_on_left == nucleus_on_left:
        nucleus_proximity = NucleusProximity.NEAR
    else:
        nucleus_proximity = NucleusProximity.FAR
    print_if_verbose(f"Nuclei proximity is {nucleus_proximity}", verbose)

    # If the satellite's nucleus has nested relations and they are too deep,
    # take the closest nested segment relative to the nucleus; otherwise take
    # the satellite's nucleus as is.
    if (
        satellite_nucleus_relation
        and get_depth(satellite_nucleus_relation) > 3
    ):
        satellite_nucleus_text = clean(
            search_segment(
                text,
                satellite_nucleus_relation,
                satellite_info.direction,
                verbose
            )
        )
    else:
        print_if_verbose(
            "Satellite doesn't have nested relations or its depth is too small.", verbose
        )
        satellite_nucleus_text = clean(
            text[
                satellite_nucleus_segment.start
                :satellite_nucleus_segment.end
            ]
        )

    return nucleus_text, satellite_nucleus_text, nucleus_proximity


class ExtendedRelationInfo(RelationInfo):
    """``RelationInfo`` enriched with extracted texts and satellite-nucleus data.

    The ``sn_`` prefix refers to the satellite's nucleus.
    """

    def __init__(
        self, rel_info, nucleus_text, nucleus_proximity, sn_relation, sn_segment, sn_text
    ):
        # Reuse the nucleus/satellite parts of the plain relation info.
        super().__init__(rel_info.nucleus_info, rel_info.satellite_info)
        self.nucleus_text, self.nucleus_proximity = nucleus_text, nucleus_proximity
        self.sn_relation, self.sn_segment = sn_relation, sn_segment
        self.sn_text = sn_text


# Lowercase wh-words (plus "how") looked up by check_satellite_nucleus.
wh_words = {"what", "when", "where", "why", "which", "how"}


# Outcome of check_satellite_nucleus: whether the text is usable, and the
# (possibly shortened) text to use.
SatelliteNucleusCheckResult = collections.namedtuple(
    "SatelliteNucleusCheckResult", "is_ok new_satellite_nucleus_text"
)


def check_satellite_nucleus(satellite_nucleus_text, verbose=False):
    """Validate the satellite's nucleus text and cut it at an inner wh-word.

    The text is tokenized with ``tokenize`` (lowercased tokens). If any token
    other than the first one is in ``wh_words``, the tokens before it are
    joined with single spaces, a full stop is appended and the check passes.
    Otherwise the check passes only when the last token is ``.``, ``!`` or
    ``?``, and the text is returned unchanged.

    NOTE(review): the shortened text is rebuilt from lowercased tokens, so
    the original casing and spacing are lost -- confirm that callers
    expect this.

    Args:
        satellite_nucleus_text: Text of the satellite's nucleus.
        verbose: When truthy, the decision taken is printed.

    Returns:
        ``SatelliteNucleusCheckResult(is_ok, new_satellite_nucleus_text)``.
        Empty or token-less input yields ``is_ok=False``.
    """
    if not satellite_nucleus_text:
        return SatelliteNucleusCheckResult(False, satellite_nucleus_text)

    sn_tokens = tokenize(satellite_nucleus_text)
    if not sn_tokens:
        # E.g. whitespace-only text: nothing to check, so reject it rather
        # than relying on an assert (asserts are stripped under -O).
        return SatelliteNucleusCheckResult(False, satellite_nucleus_text)

    # The first token is skipped on purpose: only a wh-word in the middle
    # triggers the cut.
    wh_position = next(
        (
            index
            for index, token in enumerate(sn_tokens[1:], start=1)
            if token in wh_words
        ),
        None,
    )
    if wh_position is not None:
        print_if_verbose(
            "Satellite's nucleus contains a wh-word or 'how' in the middle "
            "and will be cut at its position.",
            verbose
        )
        return SatelliteNucleusCheckResult(
            True, " ".join(sn_tokens[:wh_position]) + "."
        )

    print_if_verbose("Satellite doesn't contain a wh-word or 'how'.", verbose)
    if sn_tokens[-1] in {".", "!", "?"}:
        return SatelliteNucleusCheckResult(True, satellite_nucleus_text)

    print_if_verbose("Satellite doesn't end with a punctuation mark.", verbose)
    return SatelliteNucleusCheckResult(False, satellite_nucleus_text)
    
    
def prepare_extended_info(text, relation, verbose=False):
    """Build an ``ExtendedRelationInfo`` for ``relation``, or return ``None``.

    ``None`` is returned when the satellite has no nested relation, when the
    segment of the satellite's nucleus cannot be obtained, when either
    extracted text is empty, when the nucleus text does not end with one of
    ``. ! ? ;``, or when ``check_satellite_nucleus`` rejects the satellite's
    nucleus text.
    """
    info = get_info(relation, verbose)
    satellite_relation = info.satellite_info.relation

    if satellite_relation is None:
        print_if_verbose("Satellite doesn't have nested relations.", verbose)
        return None

    sn_relation, sn_segment = satellite_relation.get_first_nucleus()
    if sn_segment is None:
        print_if_verbose("Failed to get the segment of the satellites' nucleus.", verbose)
        return None

    extracted = extract_nuclei(text, info, sn_relation, sn_segment, verbose)
    nucleus_text, raw_sn_text, proximity = extracted

    if not nucleus_text:
        print_if_verbose("Nucleus is empty.", verbose)
        return None

    if not raw_sn_text:
        print_if_verbose("Satellite's nucleus is empty.", verbose)
        return None

    if not nucleus_text.endswith((".", "!", "?", ";")):
        print_if_verbose("Nucleus doesn't end with a punctuation mark.", verbose)
        return None

    check = check_satellite_nucleus(raw_sn_text, verbose)
    if not check.is_ok:
        return None

    # The checked text may have been shortened at a wh-word.
    return ExtendedRelationInfo(
        rel_info=info,
        nucleus_text=nucleus_text,
        nucleus_proximity=proximity,
        sn_relation=sn_relation,
        sn_segment=sn_segment,
        sn_text=check.new_satellite_nucleus_text
    )

In [62]:
def get_relation_type(relation):
    """Return ``relation.type``, or the placeholder ``"-"`` for ``None``."""
    return "-" if relation is None else relation.type

In [4]:
def remove_extra_space(s):
    """Collapse every whitespace run in ``s`` to one space and trim the ends."""
    words = s.split()
    return " ".join(words)


def contains_any_of(s, s_array):
    """Return ``True`` if at least one item of ``s_array`` is contained in ``s``."""
    return any(candidate in s for candidate in s_array)