In [1]:
import re
from io import open
import pandas as pd
from sklearn.model_selection import train_test_split

In [1]:
# Label vocabulary for SemEval-2010 Task 8. A label's position in this list is
# used as its integer class id, so the order must never change: 'Other' first,
# then each relation type in both argument directions, (e1,e2) before (e2,e1).
RELATION_LABELS = ['Other'] + list(
    '%s(%s)' % (relation_type, direction)
    for relation_type in ('Message-Topic', 'Product-Producer',
                          'Instrument-Agency', 'Entity-Destination',
                          'Cause-Effect', 'Component-Whole',
                          'Entity-Origin', 'Member-Collection',
                          'Content-Container')
    for direction in ('e1,e2', 'e2,e1')
)

In [2]:
def _split_id_sentence(line, line_no):
    """Split an ``<id><TAB>"<sentence>"`` line into ``(id, sentence)``.

    Only the first tab separates the fields, so a tab inside the sentence is
    kept. The first and last characters of the sentence field (the wrapping
    double quotes) are dropped.

    Raises:
        ValueError: if the line contains no tab.
    """
    fields = line.split("\t", 1)
    if len(fields) != 2:
        raise ValueError('line %d: expected <id><TAB>"<sentence>", got %r'
                         % (line_no, line))
    example_id, quoted_sentence = fields
    return example_id, quoted_sentence[1:-1]


def format_semeval_inputs(input_file, labels=None, encoding="utf-8"):
    """Parse a SemEval relation-classification file into a DataFrame.

    Two layouts are supported:

    * Labelled (``labels`` given and non-empty). Each record is 4 lines: an
      ``<id><TAB>"<sentence>"`` line, then the relation label, then two lines
      that are skipped (assumed to be the comment line and blank separator of
      the SemEval-2010 Task 8 format).
    * Unlabelled (``labels`` is None or empty). Each non-blank line is one
      ``<id><TAB>"<sentence>"`` record.

    Trailing blank lines at the end of the file are ignored in both layouts.

    Args:
        input_file: path to the SemEval text file.
        labels: label vocabulary; a label's list index becomes its ``relation``
            value. Pass None to parse an unlabelled file.
        encoding: text encoding used to read the file (default UTF-8, so the
            result does not depend on the platform's locale).

    Returns:
        pandas.DataFrame with columns ``["id", "sentence", "relation"]`` when
        labels are given, otherwise ``["id", "sentence"]``. Ids are strings.

    Raises:
        ValueError: on a malformed id/sentence line, a truncated record, or a
            label that is not in ``labels``.
    """
    with open(input_file, encoding=encoding) as raw_text:
        lines = [line.strip() for line in raw_text]

    # Extra newlines at EOF would otherwise start a phantom record (labelled
    # layout) or be parsed as an empty id/sentence line (unlabelled layout).
    while lines and not lines[-1]:
        lines.pop()

    if labels:  # labelled (training) file: fixed 4-line records
        data = []
        for idx in range(0, len(lines), 4):
            example_id, sentence = _split_id_sentence(lines[idx], idx + 1)
            if idx + 1 >= len(lines):
                raise ValueError("line %d: record %r has no relation label"
                                 % (idx + 1, example_id))
            label = lines[idx + 1]
            if label not in labels:
                raise ValueError("line %d: unknown relation label %r"
                                 % (idx + 2, label))
            data.append([example_id, sentence, labels.index(label)])
        return pd.DataFrame(data=data, columns=["id", "sentence", "relation"])

    # Unlabelled (test) file: one record per non-blank line.
    data = [list(_split_id_sentence(line, line_no))
            for line_no, line in enumerate(lines, start=1) if line]
    return pd.DataFrame(data=data, columns=["id", "sentence"])

In [3]:
def semeval_train_dev_split(input_file, train_file, dev_file, labels, test_size,
                            random_state=None):
    """Parse a labelled SemEval file and write a random train/dev split as TSV.

    Args:
        input_file: path to the labelled SemEval file, parsed with
            ``format_semeval_inputs(input_file, labels)``.
        train_file: output path for the training rows.
        dev_file: output path for the held-out dev rows.
        labels: relation label vocabulary; each label's list index is written
            as its class id.
        test_size: forwarded to ``train_test_split`` (a float is a fraction of
            rows, an int an absolute row count, for the dev set).
        random_state: seed forwarded to ``train_test_split``. The default
            ``None`` keeps the old behaviour of a different shuffle on every
            run; pass an int to make the split reproducible.

    Both files are written tab-separated with columns ``id, sentence,
    relation``, with no header row and no index column.
    """
    examples = format_semeval_inputs(input_file, labels)
    train_df, dev_df = train_test_split(examples, test_size=test_size,
                                        random_state=random_state)

    train_df.to_csv(train_file, sep='\t', index=False, header=False)
    dev_df.to_csv(dev_file, sep='\t', index=False, header=False)

In [6]:
# Hold out 25% of the labelled SemEval training data as a dev set and write
# both parts as header-less TSV files next to the source data.
semeval_train_dev_split(
    input_file='./data/TRAIN_FILE.TXT',
    train_file='./data/train.tsv',
    dev_file='./data/dev.tsv',
    labels=RELATION_LABELS,
    test_size=0.25,
)