In [4]:
import os
import json
import importlib
from collections import defaultdict
from tqdm import tqdm
from llm import LLM

In [2]:
data = json.load(open('datasets/retacred/train.json'))

re_stats = defaultdict(int)
for sent in data:
    re_stats[sent['relation']] += 1
ignore_keys = set(['org:website', 'per:city_of_birth'])
aug_keys = set()
for key, value in sorted(re_stats.items(), key=lambda x: x[1]):
    if value < 300 and key not in ignore_keys:
        # print(f'---{key}: {value}')
        aug_keys.add(key)
print(aug_keys)
aug_sents = [sent for sent in data if sent['relation'] in aug_keys]
print(len(aug_sents))

{'org:political/religious_affiliation', 'per:cause_of_death', 'org:founded_by', 'org:shareholders', 'per:parents', 'org:dissolved', 'per:countries_of_residence', 'org:founded', 'per:siblings', 'per:city_of_death', 'per:origin', 'per:stateorprovinces_of_residence', 'per:country_of_death', 'per:spouse', 'per:date_of_death', 'per:children', 'per:other_family', 'per:stateorprovince_of_death', 'per:stateorprovince_of_birth', 'per:religion', 'per:schools_attended', 'per:charges', 'per:date_of_birth', 'org:number_of_employees/members', 'per:cities_of_residence', 'per:country_of_birth'}
3447


In [3]:
llm = LLM()

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [None]:
import aug


importlib.reload(aug)

template = """
You are an editor who is very good at paraphrasing sentences. Your task is rewrite a given sentence well keeping the original entities.

In a sentence, two entities are nested in the sentence in the format of [[ entity ]].
Rewrite the given sentence using each given entity exactly once and do not introduce other entities.
Nest the original entities in the same format in the rewrited sentence.
You can change the content inside the entity.

%s
"""

encoded_sents = aug.get_encoded_sents(aug_sents)
save_folder = 'augs'

for relation, cur_encoded_sents in sorted(encoded_sents.items(), key=lambda x: len(x[1])):
    print('--------> processing', relation, len(cur_encoded_sents))
    file_name = f'{relation.replace("/", "--")}.json'
    save_path = os.path.join(save_folder, file_name)
    if os.path.exists(save_path):
        print(f'skip {relation} due to {save_path} exists')
        continue
    rewrited_sents = []
    for encoded_sent, head_ent_type, tail_ent_type, head_ent, tail_ent in tqdm(cur_encoded_sents):
        message = template % encoded_sent
        rewrited_sent = aug.rewrite_sent(llm, message)
        if rewrited_sent:
            # print('*' * 30)
            # print(encoded_sent)
            # print(rewrited_sent)
            rewrited_sents.append((rewrited_sent, head_ent_type, tail_ent_type, head_ent, tail_ent))
    print(f'--------> {len(rewrited_sents)} sentences generated')
    print('*' * 50)
    with open(save_path, 'w') as af:
        json.dump(rewrited_sents, af)

### Fix data in augmented sentences

In [49]:
import os
import json
from glob import glob

for key in ['train', 'dev']:
    aug_files = glob(f'augs_{key}/*.json')
    for aug_file in aug_files:
        cur_data = json.load(open(aug_file))
        new_data = []
        for idx, (rewrited, head_ent_type, tail_ent_type, head_ent, tail_ent) in enumerate(cur_data):
            sid = f'{os.path.basename(aug_file)[:-5]}_{idx}'
            sent_info = {
                'encoded_sent': [], 'head_ent_type': head_ent_type, 'tail_ent_type': tail_ent_type,
                'rewrited_sent': rewrited, 'head_ent': head_ent, 'tail_ent': tail_ent, 'id': sid}
            new_data.append(sent_info)
        out_file = os.path.join(f'augs_{key}_2', os.path.basename(aug_file))
        json.dump(new_data, open(out_file, 'w'))


### Check how many sentences augmented

In [1]:
import os
import json
from glob import glob

for key in ['train', 'dev', 'test']:
    aug_files = glob(f'augs_{key}/*.json')
    cnt = 0
    relations = set()
    for aug_file in aug_files:
        cur_data = json.load(open(aug_file))
        cnt += len(cur_data)
        relation = os.path.basename(aug_file).split('_')[0]
        relations.add(relation)
    print(f'{key} set: generated {cnt} sentences for {len(relations)} relations')

train set: generated 22346 sentences for 23 relations
dev set: generated 15152 sentences for 23 relations
test set: generated 9471 sentences for 23 relations


### Transform augmented sentences into target data format

In [29]:
import aug
import json
import importlib
from collections import defaultdict

importlib.reload(aug)

for key in ['train', 'dev', 'test']:
    folder = f'augs_{key}'
    sents = aug.get_auged_sents(folder)

    data = json.load(open(f'datasets/retacred/{key}.json'))
    json.dump(sents+data, open(f'datasets/retacred_aug/{key}.json', 'w'))
    print(f'{key} set: transformed {len(sents)}, total {len(sents+data)}')

    # check label consistency
    org_stats = defaultdict(int)
    for sent in data:
        org_stats[sent['relation']] += 1

    stats = defaultdict(int)
    for sent in sents+data:
        stats[sent['relation']] += 1

    assert len(set(stats.keys()) - set(org_stats.keys())) == 0

train set: transformed 22322, total 80787
dev set: transformed 15143, total 34727
test set: transformed 9460, total 22878


### Compare test results

In [2]:
import json
from collections import defaultdict
from prettytable import PrettyTable


def get_pred_stats(train_file, test_file, pred_file, relations = None):
    train_data = json.load(open(train_file))
    test_data = json.load(open(test_file))
    gt_id2label = {sent['id']: sent['relation'] for sent in test_data if sent['relation'] != 'no_relation'}
    
    train_stats = defaultdict(int)
    test_stats = defaultdict(int)
    for sent in train_data:
        if sent['relation'] == 'no_relation':
            continue
        train_stats[sent['relation']] += 1
    train_stats['overall'] = sum(train_stats.values())
    if relations is None:
        relations = [relation for relation, _ in sorted(train_stats.items(), key=lambda x: x[1])]

    test_class_stats = defaultdict(set)
    test_pos_sids = set()
    for sent in test_data:
        if sent['relation'] == 'no_relation':
            continue
        test_stats[sent['relation']] += 1
        sid = sent['id'].split('_')[0]
        gt_id2label[sid] = sent['relation']
        test_class_stats[sent['relation']].add(sid)
        test_pos_sids.add((sid, sent['relation']))
    test_stats['overall'] = sum(test_stats.values())

    with open(pred_file) as pred_f:
        lines = pred_f.readlines()
    preds = defaultdict(list)
    for line in lines:
        sid, pred = line.strip().split()
        # if pred == 'no_relation':
        #     continue
        sid = sid.split('_')[0]
        preds[sid].append(pred)

    all_preds = set()
    class_preds = defaultdict(set)
    for sid, s_preds in preds.items():
        pred = max(set(s_preds), key = s_preds.count)
        if pred == 'no_relation':
            continue
        all_preds.add((sid, pred))
        class_preds[pred].add(sid)

    TP = len(all_preds & test_pos_sids)
    FP = len(all_preds - test_pos_sids)
    FN = len(test_pos_sids - all_preds)
    prec = TP / (TP + FP) if TP + FP else 0.0
    recall = TP / (TP + FN) if TP + FN else 0.0
    f1 = 2 * prec * recall / (prec + recall) if prec + recall else 0.0

    print(TP, FP, FN, prec, recall, f1)

    stats = {'overall': {'TP': TP, 'FP': FP, 'FN': FN, 'support': TP+FN, 'prec': prec, 'recall': recall, 'f1': f1}}
    for relation in test_class_stats.keys():
        class_TP = len(class_preds[relation] & test_class_stats[relation])
        class_FP = len(class_preds[relation] - test_class_stats[relation])
        class_FN = len(test_class_stats[relation] - class_preds[relation])
        # print(relation, '=====>', TP, FP, FN)
        class_prec = class_TP / (class_TP + class_FP) if (class_TP + class_FP) else 0
        class_recall = class_TP / (class_TP + class_FN) if (class_TP + class_FN) else 0
        class_f1 = 2 * class_prec * class_recall / (class_prec + class_recall) if (class_prec + class_recall) else 0
        stats[relation] = {'TP': class_TP, 'FP': class_FP, 'FN': class_FN, 'support': class_TP+class_FN, 'prec': class_prec, 'recall': class_recall, 'f1': class_f1}
    return train_stats, test_stats, stats, relations


def write_stats2table(train_stats, test_stats, stats, relations, out_file):
    table = PrettyTable()
    table.field_names = ['Relation', 'TestSupport', 'TrainSupport', 'Prec', 'Recall', 'F1']
    for relation in relations:
        if relation == 'no_relation':
            continue
        row = [relation, test_stats[relation], train_stats[relation]]
        if relation in stats:
            row.extend([round(stats[relation]['prec'], 4), round(stats[relation]['recall'], 4), round(stats[relation]['f1'], 4)])
        else:
            row.extend([0.0, 0.0, 0.0])
        table.add_row(row)

    with open(out_file, 'w', newline='') as f_output:
        f_output.write(table.get_csv_string())

In [3]:
org_train_file = 'datasets/retacred/train.json'
aug_train_file = 'datasets/retacred_aug/train.json'
org_test_file = 'datasets/retacred/test.json'
aug_test_file = 'datasets/retacred_aug/test.json'
relations = None

for key in ['', '_aug_train', '_aug_test', '_aug_train_aug_dev', '_aug_train_aug_test', '_aug_train_aug_dev_aug_test']:
    print('------>', key)
    if 'aug_train' in key:
        train_file = aug_train_file
    else:
        train_file = org_train_file
    if 'aug_test' in key:
        test_file = aug_test_file
    else:
        test_file = org_test_file
    pred_file = f'tacred_dir{key}/predictions.txt'
    out_file = f'tables/table{key}.csv'
    train_stats, test_stats, stats, relations = get_pred_stats(train_file, test_file, pred_file, relations)
    write_stats2table(train_stats, test_stats, stats, relations, out_file)

------> 
4614 907 1034 0.8357181669987321 0.8169263456090652 0.8262154176739189
------> _aug_train
4755 989 893 0.8278203342618384 0.8418909348441926 0.8347963483146067
------> _aug_test
4284 938 1364 0.8203753351206434 0.7584985835694051 0.7882244710211592
------> _aug_train_aug_dev
4737 978 911 0.8288713910761155 0.8387039660056658 0.8337586904866673
------> _aug_train_aug_test
4938 1026 710 0.8279678068410463 0.8742917847025495 0.850499483293145
------> _aug_train_aug_dev_aug_test
4920 1021 728 0.828143410200303 0.8711048158640227 0.849081025110018
