In [10]:
import collections
import enum
import json
import os
import random

In [67]:
import nltk.tokenize

import pandas as pd

In [12]:
import import_ipynb
import aux.defs
import aux.relation_extraction
import aux.utils
import aux.nlp
import preparation

%run explanation_04.ipynb

In [47]:
# Root directory of the generated statement files; its immediate subdirectories
# are scanned for per-text "<text_no>.txt.tree" JSON files, and the resulting
# spreadsheet is written here as well.
# NOTE(review): the three directories below are absolute, machine-specific
# paths — adjust them before running the notebook elsewhere.
STATEMENTS_DIR = "/Users/YK/mt/project/statements_3/"
# Relative sub-path (appended to the directories here) selecting the part of
# the data set that is processed.
RACE_PART = "train/middle"
# Directory whose RACE_PART sub-folder is listed for "<number>.txt" files.
RACE_DIR = "/Users/YK/mt/RACE"
# Directory whose RACE_PART sub-folder holds the "<text_no>.txt.tree" files
# handed to the relation loader.
PARSED_RACE_DIR = "/Users/YK/mt/parsed/race"

In [119]:
class Position(enum.Enum):
    """Placement of a candidate relation relative to the true statement's span."""

    # The relation ends at or before the statement's left boundary.
    BEFORE = "before"
    # The relation overlaps the statement's span without straddling its split point.
    NESTED = "nested"
    # The relation starts at or after the statement's right boundary.
    AFTER = "after"

    
# One generated alternative statement together with the metrics describing how
# it relates to the true statement: where the source relation sits (position,
# distance in words / sentence marks), how long the substituted satellite text
# is (absolute and relative to the true one), and how similar the two
# satellite texts are (Jaccard distance over tokens, edit distance over
# characters).
Alternative = collections.namedtuple(
    "Alternative",
    "true_statement alternative_statement "
    "relation_type position "
    "distance_words distance_sentences "
    "sn_length sn_length_relative_difference "
    "jaccard_distance edit_distance",
)

In [120]:
def load_statements(directory, subdirectories, text_no):
    """Load the statements of one text from every subdirectory that has them.

    Args:
        directory: root directory containing the statement subdirectories.
        subdirectories: names of the subdirectories of ``directory`` to probe.
        text_no: number of the text; the file looked up in each subdirectory is
            ``<directory>/<subdirectory>/<RACE_PART>/<text_no>.txt.tree``.

    Returns:
        dict mapping each subdirectory name whose file exists to the parsed
        JSON content of that file; subdirectories without the file are omitted.
    """
    statements = {}
    for subdirectory in subdirectories:
        file_path = os.path.join(
            directory, subdirectory, RACE_PART, f"{text_no}.txt.tree"
        )
        if not os.path.exists(file_path):
            # No statements were produced for this text in this subdirectory.
            continue
        # JSON text is UTF-8; decode it explicitly instead of relying on the
        # platform's locale-dependent default encoding.
        with open(file_path, "rt", encoding="utf-8") as f:
            statements[subdirectory] = json.load(f)
    return statements
    

def load_relations(text_no, directory):
    """Load a text and its relations from ``<directory>/<text_no>.txt.tree``.

    Thin wrapper around ``aux.relation_extraction.load_relations`` that discards
    the third element of the triple it returns.

    Returns:
        ``(text, relations)`` exactly as produced by the underlying loader.
    """
    tree_path = os.path.join(directory, f"{text_no}.txt.tree")
    loaded = aux.relation_extraction.load_relations(tree_path)
    text, relations, _ = loaded
    return text, relations

In [121]:
# Names of the immediate subdirectories of STATEMENTS_DIR; plain files that
# live next to them are skipped.
statements_subdirectories = [
    entry_name
    for entry_name in os.listdir(STATEMENTS_DIR)
    if os.path.isdir(os.path.join(STATEMENTS_DIR, entry_name))
]

In [122]:
def get_n_words(text_span):
    """Return how many tokens NLTK's word tokenizer produces for ``text_span``."""
    tokens = nltk.tokenize.word_tokenize(text_span)
    return len(tokens)


def get_n_sentences(text_span):
    """Approximate the number of sentence boundaries inside ``text_span``.

    The estimate is the number of sentence-terminating punctuation characters
    ('.', '?' and '!') in the span. Commas are not counted: they separate
    clauses, not sentences. This is a character count, so an ellipsis ("...")
    contributes three.

    Args:
        text_span: the piece of text to inspect.

    Returns:
        int: the number of terminator characters found (0 for an empty span).
    """
    sentence_terminators = {'.', '?', '!'}
    return sum(1 for c in text_span if c in sentence_terminators)


def get_position_and_distance(statement, relation, text, verbose=False):
    """Locate ``relation`` relative to ``statement`` and measure the gap between them.

    Args:
        statement: mapping with the character offsets ``"left_boundary"``,
            ``"right_boundary"`` and ``"split_point"`` of the true statement.
        relation: object whose span runs from ``relation.left.start`` to
            ``relation.right.end``.
        text: the text both sets of offsets index into.
        verbose: when true, report relations that straddle the split point.

    Returns:
        ``(position, distance_words, distance_sentences)``. The distances are
        measured on the slice of ``text`` between the relation and the
        statement boundary (BEFORE / AFTER) or the split point (NESTED).
        ``(None, None, None)`` is returned when the relation straddles the
        statement's split point.
    """

    def measured(position, gap):
        # Every successful outcome reports the same two metrics for the gap.
        return position, get_n_words(gap), get_n_sentences(gap)

    relation_end = relation.right.end
    if relation_end <= statement["left_boundary"]:
        return measured(
            Position.BEFORE, text[relation_end:statement["left_boundary"]]
        )

    relation_start = relation.left.start
    if relation_start >= statement["right_boundary"]:
        return measured(
            Position.AFTER, text[statement["right_boundary"]:relation_start]
        )

    # From here on the relation overlaps the statement's span.
    split_point = statement["split_point"]
    if relation_start < split_point < relation_end:
        if verbose:
            print("The relation overlaps with the relation of the true statement.")
        return None, None, None

    if relation_end <= split_point:
        # The relation sits entirely on the left of the split point.
        gap = text[relation_end:split_point]
    else:
        # Otherwise it starts at or after the split point.
        gap = text[split_point:relation_start]
    return measured(Position.NESTED, gap)

    
def get_jaccard_distance(phrase_1, phrase_2):
    """Return NLTK's Jaccard distance between the token sets of two phrases.

    Each phrase is tokenized with NLTK's word tokenizer and reduced to the set
    of its distinct tokens before the distance is computed.
    """
    vocabulary_1, vocabulary_2 = (
        set(nltk.tokenize.word_tokenize(phrase)) for phrase in (phrase_1, phrase_2)
    )
    return nltk.jaccard_distance(vocabulary_1, vocabulary_2)


def get_edit_distance(phrase_1, phrase_2):
    """Return NLTK's edit distance between the two phrases, passed as raw strings."""
    distance = nltk.edit_distance(phrase_1, phrase_2)
    return distance
    

# A candidate relation bundled with its placement relative to the true
# statement and the size of the gap between the two (in words and in
# sentence marks).
RelationData = collections.namedtuple(
    "RelationData",
    "relation position distance_words distance_sentences",
)


def get_k(relation_data_list, closest, k):
    """Pick the ``k`` closest or the ``k`` farthest entries by word distance.

    Args:
        relation_data_list: iterable of objects with a ``distance_words``
            attribute.
        closest: if true, take the entries with the smallest distance,
            otherwise those with the largest.
        k: maximum number of entries to return.

    Returns:
        list of at most ``k`` entries, always ordered by ascending
        ``distance_words`` (also when the farthest entries are requested).
        An empty list when ``k`` is not positive.
    """
    # Guard explicitly: ``items[-0:]`` is ``items[0:]``, i.e. the *whole* list,
    # so without this check k == 0 would return everything for closest=False.
    if k <= 0:
        return []

    by_distance = sorted(relation_data_list, key=lambda rd: rd.distance_words)
    return by_distance[:k] if closest else by_distance[-k:]
        
        
def filter_relations(statement, relations, text, k=2):
    """Select the relations used to build alternatives for ``statement``.

    Every relation is placed relative to the statement; relations for which no
    position can be determined are dropped. From the rest, up to ``k`` are
    kept per position.

    Returns:
        list of ``RelationData``: the ``k`` closest relations BEFORE the
        statement, then the ``k`` closest AFTER it, then the ``k`` farthest
        NESTED ones.
    """
    grouped = {position: [] for position in Position}
    for candidate in relations:
        placement = get_position_and_distance(statement, candidate, text)
        position, n_words, n_sentences = placement
        if position is None:
            continue
        grouped[position].append(
            RelationData(
                relation=candidate,
                position=position,
                distance_words=n_words,
                distance_sentences=n_sentences,
            )
        )

    # (position, take the closest?) in the order the result is assembled.
    selection_plan = (
        (Position.BEFORE, True),
        (Position.AFTER, True),
        (Position.NESTED, False),
    )
    selected = []
    for position, take_closest in selection_plan:
        selected.extend(get_k(grouped[position], closest=take_closest, k=k))
    return selected
    
    
def create_alternative(statement, relation_data, true_sn_text_len, text, verbose=False):
    """Build an alternative statement from the satellite of another relation.

    The true statement's nucleus and connective are kept and the prepared
    satellite text of ``relation_data.relation`` is appended to them.

    Args:
        statement: mapping describing the true statement; the keys read here
            are ``"nucleus"``, ``"connective"``, ``"statement_text"`` and
            ``"satellite_nucleus"``.
        relation_data: ``RelationData`` of the relation that supplies the
            replacement satellite.
        true_sn_text_len: number of words of the true statement's
            ``"satellite_nucleus"`` text.
        text: the full text the relation was extracted from.
        verbose: forwarded to the project helpers; also enables the
            diagnostic prints below.

    Returns:
        An ``Alternative``, or ``None`` when the relation's satellite has no
        inner relation or its preprocessing fails.

    Raises:
        AssertionError: if ``preparation.get_info`` returns ``None``.
    """
    relation_info = preparation.get_info(relation_data.relation, verbose)
    assert relation_info is not None
    if relation_info.satellite_info.relation is None:
        if verbose:
            print("Satellite is flat.")
        return None

    satellite_handling_result = preparation.Preprocessor.handle_satellite(
        text, relation_info.satellite_info, relation_info.nucleus_info.direction, verbose
    )
    if satellite_handling_result is None:
        if verbose:
            print("Satellite preprocessing was unsuccessful.")
        return None

    full_prepared_text = satellite_handling_result.preparation_result.prepared_text
    processed_sn_text = aux.nlp.take_first_sentence_and_remove_leading_words(
        full_prepared_text, verbose
    )
    # NOTE(review): the original called bare ``utils`` and, for the fallback,
    # read ``info.satellite_preparation_result``; neither ``utils`` nor ``info``
    # is defined in this notebook (only ``aux.utils`` is imported). The
    # fallback now uses the untrimmed prepared text of *this* relation —
    # confirm that ``lowercase_first_letter`` lives in ``aux.utils``.
    prepared_sn_text = aux.utils.lowercase_first_letter(
        processed_sn_text if processed_sn_text is not None else full_prepared_text
    )
    sn_text_len = get_n_words(prepared_sn_text)

    # The relative difference is undefined for an empty true satellite; report
    # NaN rather than aborting the whole run with a ZeroDivisionError.
    if true_sn_text_len:
        sn_length_relative_difference = sn_text_len / true_sn_text_len - 1
    else:
        sn_length_relative_difference = float("nan")

    true_statement_nucleus = statement["nucleus"]
    connective = statement["connective"]
    alternative_text = f"{true_statement_nucleus}{connective}{prepared_sn_text}"
    return Alternative(
        true_statement=statement["statement_text"],
        alternative_statement=alternative_text,
        # Stored as the enum's string value so that it can be written to a table.
        position=relation_data.position.value,
        relation_type=relation_data.relation.type,
        distance_words=relation_data.distance_words,
        distance_sentences=relation_data.distance_sentences,
        sn_length=sn_text_len,
        sn_length_relative_difference=sn_length_relative_difference,
        jaccard_distance=get_jaccard_distance(
            statement["satellite_nucleus"], prepared_sn_text
        ),
        edit_distance=get_edit_distance(
            statement["satellite_nucleus"], prepared_sn_text
        )
    )

In [134]:
def generate_alternatives(text_no):
    """Generate the alternative-statement rows for one text.

    Loads the text's relations and its statements, and for every statement
    tries to build an alternative from each relation selected by
    ``filter_relations``.

    Args:
        text_no: number of the text to process.

    Returns:
        list of dicts, one per successfully created alternative: the fields of
        ``Alternative`` plus ``"text_no"``, ``"rule"`` and ``"reason"``.
    """
    text, relation_map = load_relations(
        text_no, os.path.join(PARSED_RACE_DIR, RACE_PART)
    )
    # Flatten the per-key relation lists into a single list of candidates.
    all_relations = []
    for _, grouped_relations in relation_map.items():
        all_relations.extend(grouped_relations)

    statement_map = load_statements(STATEMENTS_DIR, statements_subdirectories, text_no)

    rows = []
    for _, statements in statement_map.items():
        for statement in statements:
            true_sn_text_len = get_n_words(statement["satellite_nucleus"])
            for relation_data in filter_relations(statement, all_relations, text):
                alternative = create_alternative(
                    statement, relation_data, true_sn_text_len, text
                )
                if alternative is None:
                    continue
                row = alternative._asdict()
                row.update(
                    text_no=text_no,
                    rule=statement["rule"],
                    reason=statement["reason"][1],
                )
                rows.append(row)
    return rows


def create_df(rows):
    """Turn the generated rows into a sorted ``DataFrame``.

    Rows are ordered by text, rule, true statement and position. Within a
    position they are ordered by ascending word distance, except for nested
    relations, which are ordered by descending word distance.

    Args:
        rows: list of dicts holding (at least) the columns listed below.

    Returns:
        A new ``DataFrame`` with the columns in the fixed order below, or
        ``None`` when ``rows`` is empty.
    """
    if len(rows) > 0:
        column_order = [
            "text_no",
            "true_statement",
            "alternative_statement",
            "relation_type",
            "position",
            "distance_words",
            "distance_sentences",
            "sn_length",
            "sn_length_relative_difference",
            "jaccard_distance",
            "edit_distance",
            "rule",
            "reason",
        ]
        alternatives_df = pd.DataFrame(rows)[column_order]

        # The "position" column stores the enum *values* (plain strings), so it
        # has to be compared with Position.NESTED.value; comparing with the
        # enum member itself never matches.
        is_nested = alternatives_df["position"] == Position.NESTED.value
        # Temporary sort key: negated for nested rows, so that an ascending
        # sort lists them farthest-first.
        sort_key = alternatives_df["distance_words"] * (1 - 2 * is_nested.astype(int))

        return (
            alternatives_df
            .assign(d=sort_key)
            .sort_values(["text_no", "rule", "true_statement", "position", "d"])
            .drop(columns="d")
        )
    else:
        return None

In [144]:
# Numbers of all "<number>.txt" files of the selected RACE part.
# os.listdir returns entries in arbitrary order, so the numbers are sorted:
# any slice of this list (e.g. "the first N texts") then selects the same
# texts on every run and on every machine.
text_numbers = sorted(
    int(fn.split('.')[0])
    for fn in os.listdir(os.path.join(RACE_DIR, RACE_PART))
    if fn.endswith(".txt")
)

In [146]:
# Collect the alternative rows of the first 50 texts only, then tabulate them.
rows = []
for text_no in text_numbers[:50]:
    rows += generate_alternatives(text_no)

result_df = create_df(rows)

In [148]:
# create_df returns None when no alternatives were generated; skip the export
# in that case instead of failing with an AttributeError on None.
if result_df is not None:
    # The random hexadecimal suffix keeps repeated runs from overwriting
    # earlier exports.
    output_name = (
        f"alternatives_{RACE_PART.replace('/', '-')}_{random.randint(0, 2**32):x}.xlsx"
    )
    result_df.to_excel(os.path.join(STATEMENTS_DIR, output_name), index=False)
else:
    print("No alternatives were generated; nothing to export.")