In [17]:
from easydict import EasyDict
import logging
import os
from typing import List, Tuple
from utils import InputExample, InputFeatures
import json
import pandas as pd
import random

In [18]:
# Run configuration: data/model locations, output directory (SageMaker env var
# SM_OUTPUT_DATA_DIR when set, otherwise "/output"), and loader/model sizes.
args = EasyDict(
    batch_size=32,
    data_dir="./data",
    model_dir="./model",
    model_tarname="klue-re.tar.gz",
    output_dir=os.environ.get("SM_OUTPUT_DATA_DIR", "/output"),
    max_seq_length=512,
    relation_filename="relation_list.json",
    train_filename="klue-re-v1.1_train.json",
    valid_filename="klue-re-v1.1_dev.json",
    num_workers=4,
)

# Location of the relation label file.
relation_class_file_path = os.path.join(args.data_dir, args.relation_filename)
# Location of the training split.
train_file_path = os.path.join(args.data_dir, args.train_filename)
# Output files for the two augmentation strategies (entity swap, AEDA).
aug_entity_swap_file_path, aug_aeda_file_path = (
    os.path.join(args.data_dir, aug_name)
    for aug_name in ("train_aug_entity_swap.json", "train_aug_aeda.json")
)

In [19]:
# Load the relation label list stored under the "relations" key of the label file.
with open(relation_class_file_path, encoding="utf-8") as relation_file:
    relation_class = json.load(relation_file)["relations"]

In [20]:
# Load the training split. orient='records' expects a JSON list of example objects
# (the previous value 'recode' was a misspelling that pandas silently tolerated).
train_df = pd.read_json(train_file_path, orient='records')

In [21]:
def aeda(row, p, punctuations, rng=None):
    """Insert random punctuation at spaces (AEDA) while keeping entity offsets valid.

    Each space in ``row['sentence']`` that lies outside both entity spans is,
    with probability ``p``, replaced by ``' ' + <punctuation> + ' '``. Spaces
    inside an entity span are never modified, so entity words stay intact.

    Args:
        row: mapping (dict or pandas Series) with ``'sentence'``,
            ``'subject_entity'`` and ``'object_entity'``; each entity carries
            inclusive ``'start_idx'`` / ``'end_idx'`` character offsets.
        p: probability of inserting punctuation at an eligible space.
        punctuations: non-empty sequence of punctuation strings to sample from.
        rng: optional ``random.Random`` instance for reproducible sampling;
            defaults to the module-level ``random`` generator.

    Returns:
        ``(new_sentence, sub_add, obj_add)``: the augmented sentence and the
        offsets to add to the subject / object ``start_idx`` and ``end_idx``.
    """
    rng = random if rng is None else rng

    sentence = row['sentence']
    sub_entity = row['subject_entity']
    obj_entity = row['object_entity']
    sub_start, sub_end = sub_entity['start_idx'], sub_entity['end_idx']
    obj_start, obj_end = obj_entity['start_idx'], obj_entity['end_idx']

    pieces = []
    sub_add = 0
    obj_add = 0
    for i, ch in enumerate(sentence):
        if ch != ' ':
            pieces.append(ch)
            continue
        # The random draw comes first so one draw is consumed per space,
        # whether or not the space is eligible.
        if rng.random() >= p or sub_start <= i <= sub_end or obj_start <= i <= obj_end:
            pieces.append(' ')
            continue

        inserted = ' ' + rng.choice(punctuations) + ' '
        pieces.append(inserted)
        # One original space became len(inserted) characters.
        shift = len(inserted) - 1
        # The space is outside both spans, so for each entity it is either before
        # its start or after its end. Shifting each entity independently keeps
        # nested/overlapping spans correct too.
        if i < sub_start:
            sub_add += shift
        if i < obj_start:
            obj_add += shift

    return ''.join(pieces), sub_add, obj_add

In [22]:
def entity_swap(row, new_label):
    """Build a new example from ``row`` with its subject and object entities exchanged.

    guid, sentence and source are carried over as-is, and the label becomes
    ``new_label``. The entity objects are reused (not copied), so the returned
    dict references the same entity values as ``row``.
    """
    return {
        "guid": row["guid"],
        "sentence": row["sentence"],
        "subject_entity": row["object_entity"],
        "object_entity": row["subject_entity"],
        "label": new_label,
        "source": row["source"],
    }

In [23]:
# Punctuation marks AEDA may insert: period, comma, exclamation, question, semicolon, colon.
punctuations = list(".,!?;:")

In [24]:
# Relation labels that receive AEDA augmentation; "no_relation" is intentionally excluded.
_AEDA_ORG_RELATIONS = (
    "dissolved", "founded", "place_of_headquarters", "alternate_names",
    "member_of", "members", "political/religious_affiliation", "product",
    "founded_by", "top_members/employees", "number_of_employees/members",
)
_AEDA_PER_RELATIONS = (
    "date_of_birth", "date_of_death", "place_of_birth", "place_of_death",
    "place_of_residence", "origin", "employee_of", "schools_attended",
    "alternate_names", "parents", "children", "siblings", "spouse",
    "other_family", "colleagues", "product", "religion", "title",
)
aeda_list = (
    [f"org:{suffix}" for suffix in _AEDA_ORG_RELATIONS]
    + [f"per:{suffix}" for suffix in _AEDA_PER_RELATIONS]
)

In [9]:
# Entity-swap augmentation rules: (label of the source examples, label after
# swapping subject and object). Inverse pairs map to each other; the two
# self-mapped labels are treated as symmetric.
# NOTE(review): per:siblings and per:spouse look symmetric as well but are not
# swapped here — confirm whether that omission is intentional.
ENTITY_SWAP_RULES = [
    ("org:member_of", "org:members"),
    ("org:members", "org:member_of"),
    ("per:other_family", "per:other_family"),
    ("per:colleagues", "per:colleagues"),
    ("per:parents", "per:children"),
    ("per:children", "per:parents"),
]

# One list of swapped examples per rule, in rule order (kept as aug_data1..6
# for the concatenation cell below).
(aug_data1, aug_data2, aug_data3,
 aug_data4, aug_data5, aug_data6) = [
    [
        entity_swap(row, swapped_label)
        for _, row in train_df[train_df['label'] == source_label].iterrows()
    ]
    for source_label, swapped_label in ENTITY_SWAP_RULES
]

In [10]:
# All entity-swap examples, concatenated group by group in rule order.
aug_data = [
    *aug_data1, *aug_data2, *aug_data3,
    *aug_data4, *aug_data5, *aug_data6,
]

In [11]:
# Number of entity-swap augmented examples (shown as this cell's output).
len(aug_data)

3834

In [12]:
# ensure_ascii=False writes non-ASCII text (e.g. Korean) verbatim, so pin the
# file encoding to UTF-8 instead of relying on the platform default.
with open(aug_entity_swap_file_path, "w", encoding="utf-8") as f:
    json.dump(aug_data, f, ensure_ascii=False, indent=4)

In [25]:
# AEDA augmentation for every relation in aeda_list.
AEDA_PROB = 0.7   # probability of inserting punctuation at an eligible space
AEDA_SEED = 42    # fixed seed so re-running this cell reproduces the same file

random.seed(AEDA_SEED)


def _shift_entity(entity, offset):
    """Return a new entity dict whose start/end indices are moved by ``offset`` characters."""
    return {
        "word": entity["word"],  # unchanged: aeda never inserts inside an entity span
        "start_idx": entity["start_idx"] + offset,
        "end_idx": entity["end_idx"] + offset,
        "type": entity["type"],
    }


aeda_data = []
for relation in aeda_list:
    rel_data = train_df[train_df['label'] == relation]
    for _, row in rel_data.iterrows():
        new_sentence, sub_add, obj_add = aeda(row, AEDA_PROB, punctuations)
        aeda_data.append({
            "guid": row["guid"],
            "sentence": new_sentence,
            "subject_entity": _shift_entity(row["subject_entity"], sub_add),
            "object_entity": _shift_entity(row["object_entity"], obj_add),
            "label": row["label"],
            "source": row["source"],
        })


In [26]:
# ensure_ascii=False writes non-ASCII text (e.g. Korean) verbatim, so pin the
# file encoding to UTF-8 instead of relying on the platform default.
with open(aug_aeda_file_path, "w", encoding="utf-8") as f:
    json.dump(aeda_data, f, ensure_ascii=False, indent=4)

In [27]:
# Sanity check: each shifted entity span (end_idx inclusive) must still slice
# out exactly the entity's word. Collect every failure instead of stopping at
# the first one, and fail loudly rather than only printing.
def _span_matches(sentence, entity):
    """True if sentence[start_idx:end_idx + 1] equals the entity's word."""
    return sentence[entity['start_idx']: entity['end_idx'] + 1] == entity['word']


misaligned = [
    i for i, example in enumerate(aeda_data)
    if not (_span_matches(example['sentence'], example['subject_entity'])
            and _span_matches(example['sentence'], example['object_entity']))
]
assert not misaligned, (
    f"{len(misaligned)} AEDA examples have misaligned entity spans, e.g. rows {misaligned[:10]}"
)


In [28]:
# Number of AEDA-augmented examples (shown as this cell's output).
len(aeda_data)

22936