In [1]:
import os
import json
import importlib
from collections import defaultdict
from tqdm import tqdm
from llm import LLM

In [2]:
# Load the training split. The context manager guarantees the file handle is
# closed as soon as parsing finishes.
with open('datasets/retacred/train.json', encoding='utf-8') as train_file:
    data = json.load(train_file)

# Count how many training sentences carry each relation label.
re_stats = defaultdict(int)
for sent in data:
    re_stats[sent['relation']] += 1

# Relations that are never augmented, even when they are rare.
ignore_keys = {'org:website', 'per:city_of_birth'}

# A relation is selected for augmentation when it has fewer than 300 training
# sentences and is not explicitly ignored.
aug_keys = {
    key
    for key, value in re_stats.items()
    if value < 300 and key not in ignore_keys
}
print(aug_keys)

# Source sentences that will be handed to the paraphrasing step.
aug_sents = [sent for sent in data if sent['relation'] in aug_keys]
print(len(aug_sents))

{'org:political/religious_affiliation', 'per:cause_of_death', 'org:founded_by', 'org:shareholders', 'per:parents', 'org:dissolved', 'per:countries_of_residence', 'org:founded', 'per:siblings', 'per:city_of_death', 'per:origin', 'per:stateorprovinces_of_residence', 'per:country_of_death', 'per:spouse', 'per:date_of_death', 'per:children', 'per:other_family', 'per:stateorprovince_of_death', 'per:stateorprovince_of_birth', 'per:religion', 'per:schools_attended', 'per:charges', 'per:date_of_birth', 'org:number_of_employees/members', 'per:cities_of_residence', 'per:country_of_birth'}
3447


In [3]:
# Instantiate the project-local LLM wrapper (imported from `llm` above) once so
# later cells can reuse it.
# NOTE(review): the recorded output suggests this loads model checkpoint
# shards, so it is presumably slow — confirm before re-running.
llm = LLM()

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [26]:
import aug


# Pick up edits made to aug.py without restarting the kernel.
importlib.reload(aug)

# Prompt sent to the LLM; `%s` is replaced by one encoded source sentence.
template = """
You are an editor who is very good at paraphrasing sentences. Your task is rewrite a given sentence well keeping the original entities.

In a sentence, two entities are nested in the sentence in the format of [[ head entity ]] and << tail entity >>.
Rewrite the given sentence using each given entity exactly once and do not introduce other entities.
Nest the original entities in the same format in the rewrited sentence.
You can change the content inside the entity.

%s
"""

# Iterated below as {relation: [(encoded_sent, head_ent_type, tail_ent_type), ...]}.
encoded_sents = aug.get_encoded_sents(aug_sents)
save_folder = 'augs'
# Make sure the output folder exists before any result file is written.
os.makedirs(save_folder, exist_ok=True)

# Process the relations with the fewest sentences first so that finished
# files show up early.
for relation, cur_encoded_sents in sorted(encoded_sents.items(), key=lambda x: len(x[1])):
    print('--------> processing', relation, len(cur_encoded_sents))
    # Relation names may contain '/', which is not valid inside a file name.
    file_name = f'{relation.replace("/", "--")}.json'
    save_path = os.path.join(save_folder, file_name)
    # Resume support: skip relations whose result file is already in
    # save_folder. The check and the write below must use the same path.
    if os.path.exists(save_path):
        print(f'skip {relation} due to {save_path} exists')
        continue
    rewrited_sents = []
    for encoded_sent, head_ent_type, tail_ent_type in tqdm(cur_encoded_sents):
        message = template % encoded_sent
        rewrited_sent = aug.rewrite_sent(llm, message)
        # A falsy result is treated as "no usable rewrite" and dropped.
        if rewrited_sent:
            print('*' * 30)
            print(encoded_sent)
            print(rewrited_sent)
            rewrited_sents.append((rewrited_sent, head_ent_type, tail_ent_type))
    print(f'--------> {len(rewrited_sents)} sentences generated')
    print('*' * 50)
    # Write into save_folder so the file is found by the skip check above and
    # by anything that reads 'augs/*.json'.
    with open(save_path, 'w') as af:
        json.dump(rewrited_sents, af)

--------> processing per:country_of_death 6


  0%|          | 0/6 [00:00<?, ?it/s]

 17%|█▋        | 1/6 [01:12<06:02, 72.49s/it]

******************************
Nasrallah begins his fifth consecutive term as secretary-general of Hezbollah , a post he has held since an Israeli helicopter gunship killed his predecessor , Sheik [[ Abbas Musawi ]] , in south << Lebanon >> in 1992 .
['[[', 'Abbas', 'Musawi', ']]', 'was', 'succeeded', 'by', 'Nasrallah', 'in', 'the', 'role', 'of', 'secretary-general', 'of', 'Hezbollah', ',', 'a', 'position', 'he', 'assumed', 'following', 'his', "predecessor's", 'death', 'in', 'a', '1992', 'Israeli', 'helicopter', 'attack', 'in', 'the', 'southern', 'region', 'of', '<<', 'Lebanon', '>>', '.']


In [10]:
from glob import glob

# Count the rewritten sentences stored across all result files in 'augs/'.
aug_files = glob('augs/*.json')
cnt = 0
for aug_file in aug_files:
    # Read through a context manager so each file handle is closed promptly
    # rather than left open for the rest of the loop.
    with open(aug_file) as result_file:
        cur_data = json.load(result_file)
    cnt += len(cur_data)
print(f'generated {cnt} sentences')

433
