In [1]:
import os
import json
import importlib
from collections import defaultdict
from tqdm import tqdm
from llm import LLM

In [2]:
data = json.load(open('datasets/retacred/train.json'))

re_stats = defaultdict(int)
for sent in data:
    re_stats[sent['relation']] += 1
ignore_keys = set(['org:website', 'per:city_of_birth'])
aug_keys = set()
for key, value in sorted(re_stats.items(), key=lambda x: x[1]):
    if value < 300 and key not in ignore_keys:
        # print(f'---{key}: {value}')
        aug_keys.add(key)
print(aug_keys)
aug_sents = [sent for sent in data if sent['relation'] in aug_keys]
print(len(aug_sents))

{'org:political/religious_affiliation', 'per:cause_of_death', 'org:founded_by', 'org:shareholders', 'per:parents', 'org:dissolved', 'per:countries_of_residence', 'org:founded', 'per:siblings', 'per:city_of_death', 'per:origin', 'per:stateorprovinces_of_residence', 'per:country_of_death', 'per:spouse', 'per:date_of_death', 'per:children', 'per:other_family', 'per:stateorprovince_of_death', 'per:stateorprovince_of_birth', 'per:religion', 'per:schools_attended', 'per:charges', 'per:date_of_birth', 'org:number_of_employees/members', 'per:cities_of_residence', 'per:country_of_birth'}
3447


In [3]:
llm = LLM()

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [None]:
import aug


importlib.reload(aug)

template = """
You are an editor who is very good at paraphrasing sentences. Your task is rewrite a given sentence well keeping the original entities.

In a sentence, two entities are nested in the sentence in the format of [[ entity ]].
Rewrite the given sentence using each given entity exactly once and do not introduce other entities.
Nest the original entities in the same format in the rewrited sentence.
You can change the content inside the entity.

%s
"""

encoded_sents = aug.get_encoded_sents(aug_sents)
save_folder = 'augs'

for relation, cur_encoded_sents in sorted(encoded_sents.items(), key=lambda x: len(x[1])):
    print('--------> processing', relation, len(cur_encoded_sents))
    file_name = f'{relation.replace("/", "--")}.json'
    save_path = os.path.join(save_folder, file_name)
    if os.path.exists(save_path):
        print(f'skip {relation} due to {save_path} exists')
        continue
    rewrited_sents = []
    for encoded_sent, head_ent_type, tail_ent_type, head_ent, tail_ent in tqdm(cur_encoded_sents):
        message = template % encoded_sent
        rewrited_sent = aug.rewrite_sent(llm, message)
        if rewrited_sent:
            # print('*' * 30)
            # print(encoded_sent)
            # print(rewrited_sent)
            rewrited_sents.append((rewrited_sent, head_ent_type, tail_ent_type, head_ent, tail_ent))
    print(f'--------> {len(rewrited_sents)} sentences generated')
    print('*' * 50)
    with open(save_path, 'w') as af:
        json.dump(rewrited_sents, af)

### Check how many sentences augmented

In [55]:
import os
import json
from glob import glob

aug_files = glob('augs/*.json')
cnt = 0
relations = set()
for aug_file in aug_files:
    cur_data = json.load(open(aug_file))
    cnt += len(cur_data)
    relation = os.path.basename(aug_file).split('_')[0]
    relations.add(relation)
print(f'generated {cnt} sentences for {len(relations)} relations')

generated 14104 sentences for 18


### transform augmented sentences into target data format

In [None]:
from rapidfuzz import fuzz

def transform_aug_file(aug_file):
    relation = os.path.basename(aug_file).rsplit('_', 1)[0].replace('--', '/')
    cur_data = json.load(open(aug_file))
    new_cur_data = []
    for idx, (words, head_ent_type, tail_ent_type, head_words, tail_words) in enumerate(cur_data):
        org_sent = ' '.join(words)
        words, head_ent_type, tail_ent_type, head_words, tail_words = cur_data[idx]
        # print('----->', ' '.join(words))
        head_start = words.index('[[') - 0
        head_end = words.index(']]') - 1
        tail_start = words.index('[[', head_start+1) - 1
        tail_end = words.index(']]', head_end+2) - 2
        words.remove('[[')
        words.remove(']]')

        if head_start >= head_end or tail_start >= tail_end:
            print('======>', org_sent)
            continue

        assert head_end < len(words)
        assert tail_end < len(words)
        assert head_start >= 0
        assert tail_start >= 0
        assert head_start < head_end
        assert tail_start < tail_end

        # print(relation, '---=======>', words[head_start: head_end])
        # print(relation, '---=======>', words[tail_start: tail_end])
        # print(relation, '---=======>', head_words)
        # print(relation, '---=======>', tail_words)

        cur_head_words = words[head_start: head_end]
        cur_tail_words = words[tail_start: tail_end]
        head_head_ratio = fuzz.ratio(' '.join(cur_head_words), ' '.join(head_words))
        head_tail_ratio = fuzz.ratio(' '.join(cur_head_words), ' '.join(tail_words))
        tail_head_ratio = fuzz.ratio(' '.join(cur_tail_words), ' '.join(head_words))
        tail_tail_ratio = fuzz.ratio(' '.join(cur_tail_words), ' '.join(tail_words))
        if head_head_ratio < head_tail_ratio:
            head_start, head_end, tail_start, tail_end = tail_start, tail_end, head_start, head_end

        sid = f'{os.path.basename(aug_file)[:-5]}_{idx}'
        sent = {
            'id': sid, 'token': words, 'subj_start': head_start, 'subj_end': head_end-1,
            'obj_start': tail_start, 'obj_end': tail_end-1, 'relation': relation,
            'subj_type': head_ent_type, 'obj_type': tail_ent_type}
        new_cur_data.append(sent)
    return new_cur_data

aug_files = glob('augs/*.json')
sents = []
for aug_file in aug_files:
    sents.extend(transform_aug_file(aug_file))

train_data = json.load(open('datasets/retacred/train.json'))
json.dump(sents+train_data, open('datasets/retacred_aug/train.json', 'w'))

print(f'transformed {len(sents)}, total {len(sents+train_data)}')

### Check data consistency

In [62]:
from collections import defaultdict

org_stats = defaultdict(int)
for sent in train_data:
    org_stats[sent['relation']] += 1

print(len(org_stats))

stats = defaultdict(int)
for sent in sents+train_data:
    stats[sent['relation']] += 1

print(len(stats))

print(set(stats.keys()) - set(org_stats.keys()))

40
40
set()


### Compare test results

In [98]:
def get_pred_stats(train_file, test_file, pred_file):
    train_data = json.load(open(train_file))
    test_data = json.load(open(test_file))
    gt_id2label = {sent['id']: sent['relation'] for sent in test_data if sent['relation'] != 'no_relation'}
    
    train_stats = defaultdict(int)
    for sent in train_data:
        if sent['relation'] == 'no_relation':
            continue
        train_stats[sent['relation']] += 1

    test_class_stats = defaultdict(set)
    test_pos_sids = set()
    for sent in test_data:
        if sent['relation'] == 'no_relation':
            continue
        gt_id2label[sent['id']] = sent['relation']
        test_class_stats[sent['relation']].add(sent['id'])
        test_pos_sids.add((sent['id'], sent['relation']))

    with open(pred_file) as pred_f:
        lines = pred_f.readlines()
    all_preds = set()
    class_preds = defaultdict(set)
    for line in lines:
        sid, pred = line.strip().split()
        if pred == 'no_relation':
            continue
        all_preds.add((sid, pred))
        class_preds[pred].add(sid)

    TP = len(all_preds & test_pos_sids)
    FP = len(all_preds - test_pos_sids)
    FN = len(test_pos_sids - all_preds)
    # print('pred =========>', all_preds)
    # print('GT =========>', test_pos_sids)
    print(TP, FP, FN)
    prec = TP / (TP + FP)
    recall = TP / (TP + FN)
    f1 = 2 * prec * recall / (prec + recall)

    print(TP, FP, FN, prec, recall, f1)

    stats = {'overall': {'TP': TP, 'FP': FP, 'FN': FN, 'support': TP+FN, 'prec': prec, 'recall': recall, 'f1': f1}}
    for relation in test_class_stats.keys():
        class_TP = len(class_preds[relation] & test_class_stats[relation])
        class_FP = len(class_preds[relation] - test_class_stats[relation])
        class_FN = len(test_class_stats[relation] - class_preds[relation])
        # print(relation, '=====>', TP, FP, FN)
        class_prec = class_TP / (class_TP + class_FP) if (class_TP + class_FP) else 0
        class_recall = class_TP / (class_TP + class_FN) if (class_TP + class_FN) else 0
        class_f1 = 2 * class_prec * class_recall / (class_prec + class_recall) if (class_prec + class_recall) else 0
        stats[relation] = {'TP': class_TP, 'FP': class_FP, 'FN': class_FN, 'support': class_TP+class_FN, 'prec': class_prec, 'recall': class_recall, 'f1': class_f1}
        # print(relation, '=====>', class_prec, class_recall, class_f1)

    return train_stats, stats

In [99]:
train_file = 'datasets/retacred/train.json'
test_file = 'datasets/retacred/test.json'
pred_file = 'tacred_dir/predictions.txt'
train_stats, stats = get_pred_stats(train_file, test_file, pred_file)

4614 907 1034
4614 907 1034 0.8357181669987321 0.8169263456090652 0.8262154176739189


In [100]:
train_file = 'datasets/retacred_aug/train.json'
test_file = 'datasets/retacred_aug/test.json'
pred_file = 'tacred_dir_aug/predictions.txt'
aug_train_stats, aug_stats = get_pred_stats(train_file, test_file, pred_file)

4737 934 911
4737 934 911 0.8353024157996826 0.8387039660056658 0.8369997349589186


In [105]:
from prettytable import PrettyTable

test_stats = defaultdict(int)
test_data = json.load(open(test_file))
for sent in test_data:
    if sent['relation'] == 'no_relation':
        continue
    test_stats[sent['relation']] += 1
test_stats['overall'] = sum(test_stats.values())
aug_train_stats['overall'] = sum(aug_train_stats.values())

# print('====>', len(stats), len(aug_stats), stats.keys())
table = PrettyTable()
table.field_names = ['Relation', 'TestSupport', 'TrainSupport', 'Prec', 'Recall', 'F1', 'AugTrainSupport', 'AugPrec', 'AugRecall', 'AugF1']
for relation, train_support in [('overall', sum(train_stats.values()))]+sorted(train_stats.items(), key=lambda x: x[1]):
    if relation == 'no_relation':
        continue
    row = [relation, test_stats[relation], train_support]
    if relation in stats:
        row.extend([round(stats[relation]['prec'], 4), round(stats[relation]['recall'], 4), round(stats[relation]['f1'], 4)])
    else:
        row.extend([0.0, 0.0, 0.0])
    row.append(aug_train_stats[relation])
    if relation in aug_stats:
        row.extend([round(aug_stats[relation]['prec'], 4), round(aug_stats[relation]['recall'], 4), round(aug_stats[relation]['f1'], 4)])
    else:
        row.extend([0.0, 0.0, 0.0])
    table.add_row(row)

print(table)

with open('table.csv', 'w', newline='') as f_output:
    f_output.write(table.get_csv_string())

+-------------------------------------+-------------+--------------+--------+--------+--------+-----------------+---------+-----------+--------+
|               Relation              | TestSupport | TrainSupport |  Prec  | Recall |   F1   | AugTrainSupport | AugPrec | AugRecall | AugF1  |
+-------------------------------------+-------------+--------------+--------+--------+--------+-----------------+---------+-----------+--------+
|               overall               |     5648    |    19704     | 0.8357 | 0.8169 | 0.8262 |      68672      |  0.8353 |   0.8387  | 0.837  |
|         per:country_of_death        |      14     |      6       |   0    |  0.0   |   0    |        48       |   1.0   |   0.3571  | 0.5263 |
|            org:dissolved            |      5      |      23      |   0    |  0.0   |   0    |       141       |   0.5   |    0.2    | 0.2857 |
|         per:country_of_birth        |      0      |      25      |  0.0   |  0.0   |  0.0   |       182       |   0.0   |    0.0