In [1]:
import pandas as pd
from tqdm import tqdm

In [2]:
DATA_PATH = "../../data/parsed_data"

ner_train_path = f"{DATA_PATH}/ner_train.parquet"
rel_train_path = f"{DATA_PATH}/rel_train.parquet"

ner_test_path = f"{DATA_PATH}/ner_test.parquet"
rel_test_path = f"{DATA_PATH}/rel_test.parquet"

In [3]:
pd.read_parquet(f"{DATA_PATH}/{rel_test_path}").head()

Unnamed: 0,text,arg1,arg2,label,sid,uid
0,"MEDICATIONS|:|Lipitor|,|Tylenol|with|Codeine|,...",13:14,12:13,Frequency-Drug,0,100130
0,"MEDICATIONS|:|Lipitor|,|Tylenol|with|Codeine|,...",17:19,12:13,Duration-Drug,1,100130
0,She|was|started|on|prophylactic|Oxacillin|to|c...,22:24,5:6,Reason-Drug,2,100130
0,"Unit|was|uneventful|,|and|she|was|discharged|t...",17:19,19:21,Frequency-Drug,3,100130
0,1|hour|Pred|Forte|application|to|the|eye|and|c...,7:8,2:4,Route-Drug,4,100130


In [12]:
class FlairDataWriter:
    """
    A class to write NER and Relation extraction data in Flair format.

    Args:
        ner_data_path (str): The path to NER data in parquet format.
        output_path (str): The path to save the output files.
        rel_data_path (str, optional): The path to Relation extraction data in parquet format. Defaults to None.

    Attributes:
        ner_data (DataFrame): A pandas DataFrame containing the NER data.
        output_path (str): The path to save the output files.
        rel_data (DataFrame): A pandas DataFrame containing the relation extraction data. Default is None.

    Methods:
        __collapse(df): A private method that collapses a DataFrame into a single row.
        save_rel_data(name): A method that saves the relation extraction data in Flair format.
        save_ner_data(name): A method that saves the NER data in Flair format.
    """
    def __init__(self, ner_data_path, output_path, rel_data_path=None):
        """
        Initializes the FlairDataWriter object.

        Args:
            ner_data_path (str): The path to NER data in parquet format.
            output_path (str): The path to save the output files.
            rel_data_path (str, optional): The path to Relation extraction data in parquet format. Defaults to None.
        """
        self.ner_data = pd.read_parquet(ner_data_path)
        self.output_path = output_path
        self.rel_data = None
        if rel_data_path:
            self.rel_data = pd.read_parquet(rel_data_path)
            self.rel_data['UUID'] = range(self.rel_data.shape[0])

    def __collapse(self, df):
        """
        A private method that collapses a DataFrame into a single row.

        Args:
            df (DataFrame): The pandas DataFrame to be collapsed.

        Returns:
            DataFrame: A pandas DataFrame with a single row containing the collapsed data.
        """
        token_str = "|".join(df['token'].to_list())
        tag_str = "|".join(df['tag'].to_list())
        return pd.DataFrame([[token_str, tag_str]], columns=['text', 'tags'])

    def save_rel_data(self, name):
        """
        Saves the relation extraction data in Flair format.

        Args:
            name (str): The name of the file to be saved.
        """
        assert self.rel_data is not None, "Provide Relations Data!"
        ner_collapsed = self.ner_data[self.ner_data['contains_rel']==1].groupby(
            ['uid', 'sid']
        ).apply(self.__collapse).droplevel(2).reset_index()
        
        merged_data = pd.merge(
            self.rel_data, ner_collapsed, on=['uid', 'text'], how='left'
        ).drop(columns=['sid_x', 'sid_y'])

        assert merged_data.isnull().sum().sum() == 0
        assert merged_data.shape[0] == self.rel_data.shape[0]

        with open(f"{self.output_path}/{name}.txt", "w") as file:
            file.write("# global.columns = id form ner\n")
            for idx, row in tqdm(merged_data.iterrows(), total=merged_data.shape[0]):
                words = row["text"].split("|")
                tags = row["tags"].split("|")
                sent = " ".join(words)
                file.write(f"# text = {sent}\n")
                file.write(f"# sentence_id = {row['UUID']}\n")
                a1 = row["arg1"].split(":")
                a2 = row["arg2"].split(":")
                file.write(f'# relations = {int(a1[0])+1};{int(a1[1])};{int(a2[0])+1};{int(a2[1])};{row["label"]}\n')
                for i, (word, tag) in enumerate(zip(words, tags)):
                    file.write(f"{i+1} {word} {tag}\n")
                file.write("\n")

    def save_ner_data(self, name):
        """
        Saves the NER data in Flair format.

        Args:
            name (str): The name of the file to be saved.
        """
        with open(f"{self.output_path}/{name}.txt", "w") as file:
            curr_sid = 0
            for idx, row in tqdm(self.ner_data.iterrows(), total=self.ner_data.shape[0]):
                if row["sid"] != curr_sid:
                    curr_sid = row["sid"]
                    file.write("\n")
                file.write(row["token"] + " " + row["tag"] + "\n")

In [13]:
flair_train_writer = FlairDataWriter(ner_train_path, DATA_PATH, rel_train_path)

In [21]:
flair_train_writer.save_ner_data('flair_ner_train')

100%|██████████| 1435233/1435233 [01:47<00:00, 13312.49it/s]


In [14]:
flair_train_writer.save_rel_data('flair_rel_train')

100%|██████████| 36346/36346 [00:07<00:00, 5131.04it/s]


In [15]:
flair_test_writer = FlairDataWriter(ner_test_path, DATA_PATH, rel_test_path)

In [7]:
flair_test_writer.save_ner_data('flair_ner_test')

100%|██████████| 931604/931604 [01:11<00:00, 13047.09it/s]


In [16]:
flair_test_writer.save_rel_data('flair_rel_test')

100%|██████████| 23462/23462 [00:04<00:00, 5018.79it/s]
