In [2]:
import pandas as pd
from tqdm import tqdm

In [3]:
DATA_PATH = "../../data/parsed_data"

ner_train_path = f"{DATA_PATH}/ner_train.parquet"
rel_train_path = f"{DATA_PATH}/rel_train.parquet"

ner_test_path = f"{DATA_PATH}/ner_test.parquet"
rel_test_path = f"{DATA_PATH}/rel_test.parquet"

In [9]:
class FlairDataWriter:
    def __init__(self, ner_data_path, output_path, rel_data_path=None):
        self.ner_data = pd.read_parquet(ner_data_path)
        self.output_path = output_path
        self.rel_data = None
        if rel_data_path:
            self.rel_data = pd.read_parquet(rel_data_path)
            self.rel_data['UUID'] = range(self.rel_data.shape[0])

    def __collapse(self, df):
        token_str = "|".join(df['token'].to_list())
        tag_str = "|".join(df['tag'].to_list())
        return pd.DataFrame([[token_str, tag_str]], columns=['text', 'tags'])

    def save_rel_data(self, name):
        assert self.rel_data is not None, "Provide Relations Data!"
        ner_collapsed = self.ner_data[self.ner_data['contains_rel']==1].groupby(
            ['uid', 'sid']
        ).apply(self.__collapse).droplevel(2).reset_index()
        
        merged_data = pd.merge(
            self.rel_data, ner_collapsed, on=['uid', 'text'], how='left'
        ).drop(columns=['sid_x', 'sid_y'])

        assert merged_data.isnull().sum().sum() == 0
        assert merged_data.shape[0] == self.rel_data.shape[0]

        with open(f"{self.output_path}/{name}.txt", "w") as file:
            file.write("# global.columns = id form tag\n")
            for idx, row in tqdm(merged_data.iterrows(), total=merged_data.shape[0]):
                words = row["text"].split("|")
                tags = row["tags"].split("|")
                sent = " ".join(words)
                file.write(f"# text = {sent}\n")
                file.write(f"# sentence_id = {row['UUID']}\n")
                a1 = row["arg1"].split(":")
                a2 = row["arg2"].split(":")
                file.write(f'# relations = {a1[0]};{int(a1[1])-1};{a2[0]};{int(a2[1])-1};{row["label"]}\n')
                for i, (word, tag) in enumerate(zip(words, tags)):
                    file.write(f"{i} {word} {tag}\n")
                file.write("\n")

    def save_ner_data(self, name):
        with open(f"{self.output_path}/{name}.txt", "w") as file:
            curr_sid = 0
            for idx, row in tqdm(self.ner_data.iterrows(), total=self.ner_data.shape[0]):
                if row["sid"] != curr_sid:
                    curr_sid = row["sid"]
                    file.write("\n")
                file.write(row["token"] + " " + row["tag"] + "\n")

In [24]:
flair_train_writer = FlairDataWriter(ner_train_path, DATA_PATH, rel_train_path)

In [21]:
flair_train_writer.save_ner_data('flair_ner_train')

100%|██████████| 1435233/1435233 [01:47<00:00, 13312.49it/s]


In [25]:
flair_train_writer.save_rel_data('flair_rel_train')

100%|██████████| 36346/36346 [00:05<00:00, 6306.49it/s]


In [10]:
flair_test_writer = FlairDataWriter(ner_test_path, DATA_PATH, rel_test_path)

In [7]:
flair_test_writer.save_ner_data('flair_ner_test')

100%|██████████| 931604/931604 [01:11<00:00, 13047.09it/s]


In [12]:
flair_test_writer.save_rel_data('flair_rel_test')

100%|██████████| 23462/23462 [00:03<00:00, 6450.52it/s]
