In [139]:
import pandas as pd
from pathlib import Path
import os
from ast import literal_eval as load

def get_entity_span(annotations: list) -> list:
    """
    Locate every entity in a BIO tag sequence.

    An entity opens on a ``B-p`` or ``B-n`` tag and extends over the run of
    ``I`` tags carrying the same suffix that directly follows it.
    :param annotations: BIO annotation
    :return: Half-open ``(start, end)`` index pairs, ordered by ``start``
    """
    total = len(annotations)
    found = []
    for suffix in ('-p', '-n'):
        begin_tag = 'B' + suffix
        inside_tag = 'I' + suffix
        for pos, tag in enumerate(annotations):
            if tag != begin_tag:
                continue
            # Walk forward over the continuation tags; ``stop`` ends up one
            # past the last token that belongs to this entity.
            stop = pos + 1
            while stop < total and annotations[stop] == inside_tag:
                stop += 1
            found.append((pos, stop))
    # Spans were gathered one entity type at a time; interleave them by start.
    found.sort(key=lambda pair: pair[0])
    return found

def output_submission(entity_data: "pd.DataFrame | str", triple_data: "pd.DataFrame | str", out_dir: str):
    """
    Format the output from prediction into the final output form.

    For every topic / paper pair found in the entity predictions, the
    following files are written below ``out_dir/<topic>/<paper_idx>/``:

    * ``triples/<label>.txt`` - one ``(a||b||c)`` line per predicted triple,
      followed by a ``(Contribution||has||<Label>)`` line for every label
      other than ``code`` and ``research-problem``
    * ``entities.txt`` - ``idx<TAB>start<TAB>end<TAB>phrase`` per entity span,
      with character offsets computed on the whitespace-normalised sentence
    * ``sentences.txt`` - the ``idx`` of each entity row, one per line, in
      first-seen order

    Every file is rewritten from scratch, so calling the function twice with
    the same ``out_dir`` yields the same content instead of appending to it.
    Files left behind by earlier runs for other labels/papers are not removed.

    :param entity_data: The dataframe prediction of entities or path to csv file
    :param triple_data: The dataframe prediction of triples or path to csv file
    :param out_dir: Output directory
    """
    # isinstance (instead of an exact type comparison) also accepts str
    # subclasses and pathlib.Path objects as csv locations.
    if isinstance(entity_data, (str, os.PathLike)):
        entity_data = pd.read_csv(entity_data)
    if isinstance(triple_data, (str, os.PathLike)):
        triple_data = pd.read_csv(triple_data)

    triple_list = ['triple_A', 'triple_B', 'triple_C', 'triple_D']
    triple_frame = triple_data[['paper_idx', 'topic', 'labels', 'subj/obj'] + triple_list]
    entity_frame = entity_data[['text', 'topic', 'paper_idx', 'idx', 'BIO_1']]

    for topic in entity_frame['topic'].drop_duplicates():
        topic_entity = entity_frame[entity_frame['topic'] == topic]
        topic_triple = triple_frame[triple_frame['topic'] == topic]
        for paper_index in topic_entity['paper_idx'].drop_duplicates():
            entity_df = topic_entity[topic_entity['paper_idx'] == paper_index]
            triple_df = topic_triple[topic_triple['paper_idx'] == paper_index]
            dir_path = os.path.join(out_dir, topic, str(paper_index))
            triple_dir = os.path.join(dir_path, 'triples')
            Path(triple_dir).mkdir(parents=True, exist_ok=True)

            # Collect the lines of each triple file in memory first so every
            # file is written exactly once (file name -> list of lines).
            triple_files = {}
            for _, row in triple_df.iterrows():
                info_unit = row['labels']
                lines = triple_files.setdefault(f'{info_unit}.txt', [])
                if info_unit == 'code':
                    lines.extend(f"(Contribution||Code||{p[0]})\n"
                                 for p in load(row['subj/obj']))
                elif info_unit == 'research-problem':
                    lines.extend(f"(Contribution||has research problem||{p[0]})\n"
                                 for p in load(row['subj/obj']))
                else:
                    for triple in triple_list:
                        lines.extend(f"({'||'.join(t)})\n" for t in load(row[triple]))

            for file_name, lines in triple_files.items():
                # Add specific tuple: link the information unit to the
                # contribution, using the file stem with dashes as spaces.
                unit_name = file_name.split('.')[0].replace('-', ' ')
                if unit_name not in ('code', 'research problem'):
                    # Slicing keeps an empty name from raising IndexError.
                    lines.append(f"(Contribution||has||{unit_name[:1].upper() + unit_name[1:]})\n")
                with open(os.path.join(triple_dir, file_name), 'w', encoding='utf-8') as f_triple:
                    f_triple.writelines(lines)

            # Getting entities
            entity_lines = []
            # A dict serves as an insertion-ordered set, keeping the order of
            # sentences.txt deterministic across runs.
            sentence_ids = {}
            for _, row in entity_df.iterrows():
                idx = row['idx']
                words = row['text'].split()
                sentence_ids[str(idx)] = None
                for st, ed in get_entity_span(load(row['BIO_1'])):
                    # Offsets assume the words are joined by single spaces;
                    # end_idx drops the separator counted after the last word.
                    start_idx = sum(len(word) + 1 for word in words[:st])
                    end_idx = sum(len(word) + 1 for word in words[:ed]) - 1
                    phrase = ' '.join(words[st:ed])
                    entity_lines.append(f'{idx}\t{start_idx}\t{end_idx}\t{phrase}\n')

            with open(os.path.join(dir_path, 'entities.txt'), 'w', encoding='utf-8') as f_entity:
                f_entity.writelines(entity_lines)
            with open(os.path.join(dir_path, 'sentences.txt'), 'w', encoding='utf-8') as f_sentence:
                f_sentence.write('\n'.join(sentence_ids))


In [140]:
# Delete any output left over from an earlier run before regenerating it.
!rm -rf submission

In [141]:
# Build the submission tree under ./submission from the entity predictions
# (pos_sent.csv) and the triple predictions (triples.csv).
output_submission('pos_sent.csv','triples.csv','submission')