In [3]:
import json
import tqdm
import sys

sys.path.append('..')
from preprocess_data import (raw_train_data_file, raw_dev_data_file, raw_test_data_file, processed_dev_data_file, processed_test_data_file, processed_train_data_file, generate_seq)

In [4]:
# Load the raw training split and keep only its 'entries' list.
# A context manager closes the file handle (the bare open() leaked it), and an
# explicit UTF-8 encoding avoids platform-default decoding of non-ASCII names.
with open(raw_train_data_file, encoding='utf-8') as f_in:
    raw_train_data = json.load(f_in)['entries']

In [5]:
# Raw entries appear to wrap each sample under one key (TODO confirm);
# unwrap every entry by taking the value stored under its first key.
data_samples = [list(entry.values())[0] for entry in raw_train_data]

In [6]:
# Keys dropped from every sample. Built once, outside the loop, because it
# does not depend on the sample being processed.
unused_labels = ['originaltriplesets', 'xml_id', 'size', 'shape', 'shape_type']

new_data_samples = []
for data_sample in tqdm.tqdm(data_samples):
    # Shallow copy so the pops below do not remove keys from the raw sample.
    new_data_sample = data_sample.copy()
    for label in unused_labels:
        # dict.pop without a default raises KeyError if a sample lacks the key.
        new_data_sample.pop(label)
    triples = new_data_sample['modifiedtripleset']
    # generate_seq's two return values are stored as the model input sequence
    # and the list of properties used by the triples.
    new_data_sample['input_seq'], new_data_sample['properties'] = generate_seq(triples)
    # Replace the lexicalisation records with just their 'lex' text.
    new_data_sample['target_sents'] = [sent['lex'] for sent in new_data_sample.pop('lexicalisations')]
    new_data_samples.append(new_data_sample)

100%|██████████| 12876/12876 [00:00<00:00, 42597.49it/s]


In [7]:
# Inspect the first processed sample to eyeball the output schema.
new_data_samples[0]

{'category': 'Airport',
 'modifiedtripleset': [{'object': '"Aarhus, Denmark"',
   'property': 'cityServed',
   'subject': 'Aarhus_Airport'}],
 'input_seq': ['Aarhus_Airport', '[P]', '"Aarhus, Denmark"'],
 'properties': ['cityServed'],
 'target_sents': ['The Aarhus is the airport of Aarhus, Denmark.',
  'Aarhus Airport serves the city of Aarhus, Denmark.']}

In [None]:
# Check that a processed sample is JSON-serializable (json.dumps raises
# TypeError otherwise).
# NOTE(review): this cell was never executed (In [None]); run it or remove it.
json.dumps(new_data_samples[0])

## Check entity and relation coverage (dev vs. train)

In [8]:
def read_jsonl(path):
    """Read a JSON Lines file into a list of parsed records.

    Args:
        path: path to a file with one JSON document per line.

    Returns:
        A list with one parsed object per non-blank line, in file order.

    Blank lines are skipped, because ``json.loads('')`` would raise
    ``json.JSONDecodeError``. The file is decoded as UTF-8 explicitly rather
    than with the platform default encoding.
    """
    with open(path, encoding='utf-8') as f_in:
        return [json.loads(line) for line in f_in if line.strip()]


train_data = read_jsonl(processed_train_data_file)
dev_data = read_jsonl(processed_dev_data_file)

In [9]:
# Inspect the first sample read back from the processed training file
# (compare with new_data_samples[0] shown earlier).
train_data[0]

{'category': 'Airport',
 'modifiedtripleset': [{'object': '"Aarhus, Denmark"',
   'property': 'cityServed',
   'subject': 'Aarhus_Airport'}],
 'input_seq': ['Aarhus_Airport', '[P]', '"Aarhus, Denmark"'],
 'properties': ['cityServed'],
 'target_sents': ['The Aarhus is the airport of Aarhus, Denmark.',
  'Aarhus Airport serves the city of Aarhus, Denmark.']}

In [10]:
from collections import Counter


def collect_triple_vocab(samples):
    """Collect entities and relation occurrences from processed samples.

    Args:
        samples: iterable of dicts, each with a 'modifiedtripleset' list whose
            items have 'subject', 'property' and 'object' keys.

    Returns:
        (entities, relation_mentions): the set of all subjects and objects,
        and a list of every 'property' occurrence (repeats kept, so it can be
        fed to a Counter).
    """
    entities = set()
    relation_mentions = []
    for sample in samples:
        for triple in sample['modifiedtripleset']:
            entities.add(triple['subject'])
            entities.add(triple['object'])
            relation_mentions.append(triple['property'])
    return entities, relation_mentions


train_entities_set, train_relation_mentions = collect_triple_vocab(train_data)
train_relation_set = set(train_relation_mentions)

# The original cell stored the dev relation occurrences under the misleading
# name `train_relation_list`; `c` counts *dev* relations, as before.
dev_entities_set, dev_relation_list = collect_triple_vocab(dev_data)
dev_relation_set = set(dev_relation_list)

c = Counter(dev_relation_list)

# Entity counts: train, dev, and dev entities never seen in train.
print(len(train_entities_set))
print(len(dev_entities_set))
uncovered_entities_set = dev_entities_set - train_entities_set
print(len(uncovered_entities_set))
print('')
# Relation counts: train, dev, and dev relations never seen in train.
print(len(train_relation_set))
print(len(dev_relation_set))
uncovered_relation_set = dev_relation_set - train_relation_set
print(len(uncovered_relation_set))

3105
2003
19

373
288
0


In [15]:
# Every dev relation, most frequent first. most_common() with no argument
# returns all entries, which replaces the hard-coded 288 (== len(c)).
c.most_common()

[('country', 349),
 ('leaderName', 207),
 ('location', 195),
 ('birthPlace', 194),
 ('isPartOf', 185),
 ('club', 138),
 ('ethnicGroup', 112),
 ('associatedBand/associatedMusicalArtist', 109),
 ('language', 97),
 ('genre', 94),
 ('region', 86),
 ('deathPlace', 84),
 ('ingredient', 83),
 ('capital', 82),
 ('ground', 53),
 ('battles', 53),
 ('manager', 48),
 ('nationality', 46),
 ('birthDate', 45),
 ('almaMater', 45),
 ('mainIngredients', 40),
 ('leaderTitle', 39),
 ('creator', 38),
 ('runwayLength', 36),
 ('operatingOrganisation', 35),
 ('demonym', 34),
 ('league', 34),
 ('city', 34),
 ('epoch', 34),
 ('cityServed', 32),
 ('successor', 32),
 ('elevationAboveTheSeaLevel_(in_metres)', 31),
 ('office (workedAt, workedAs)', 30),
 ('party', 29),
 ('numberOfMembers', 29),
 ('owner', 29),
 ('background', 28),
 ('engine', 28),
 ('was a crew member of', 28),
 ('dishVariation', 27),
 ('builder', 27),
 ('discoverer', 26),
 ('manufacturer', 26),
 ('stylisticOrigin', 25),
 ('operator', 25),
 ('occupa

In [12]:
# Number of distinct relations counted in `c`.
len(c)

288

In [31]:
import re

# Lazily consume characters up to a word boundary. A boundary is either a
# lower->Upper transition ("city|Served"), the last capital of an acronym
# before a capitalised word ("HTML|Parser"), or end of string.
# Raw strings avoid invalid-escape warnings.
_CAMEL_CASE_RE = re.compile(r'.+?(?:(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])|$)')
# Any character that is not alphanumeric, a newline or a dot is a separator.
_LABEL_SEPARATOR_RE = re.compile(r'[^a-zA-Z0-9\n.]')


def camel_case_split(identifier):
    """Split a camelCase / PascalCase identifier into its words.

    Returns an empty list for an empty string.

    >>> camel_case_split('elevationAboveTheSeaLevel')
    ['elevation', 'Above', 'The', 'Sea', 'Level']
    """
    return [m.group(0) for m in _CAMEL_CASE_RE.finditer(identifier)]


def split_label(label):
    """Split a relation label into words.

    The label is first split on separator characters (underscores, spaces,
    brackets, commas, slashes, ...). Each piece is then camel-case split, and
    empty pieces contribute nothing. A flat comprehension replaces the
    quadratic ``sum(lists, [])`` flattening.

    >>> split_label('office (workedAt, workedAs)')
    ['office', 'worked', 'At', 'worked', 'As']
    """
    return [word
            for piece in _LABEL_SEPARATOR_RE.split(label)
            for word in camel_case_split(piece)]

In [1]:
import yaml

In [2]:
# Experiment config file, resolved relative to the notebook's working directory.
CONFIG_PATH = 'test.yaml'

# Open in binary mode so PyYAML detects the stream encoding itself (UTF-8 by
# default, UTF-16 via BOM) instead of using the platform's default text
# encoding. safe_load builds only plain Python objects.
with open(CONFIG_PATH, 'rb') as f_in:
    config = yaml.safe_load(f_in)

In [3]:
# List the top-level keys of the loaded config.
config.keys()

dict_keys(['dataset', 'val_split_ratio', 'pretrained_word_emb_name', 'out_dir', 'graph_construction_args', 'gl_metric_type', 'gl_epsilon', 'gl_top_k', 'gl_num_heads', 'gl_num_hidden', 'gl_smoothness_ratio', 'gl_sparsity_ratio', 'gl_connectivity_ratio', 'init_adj_alpha', 'word_dropout', 'rnn_dropout', 'no_fix_word_emb', 'emb_strategy', 'gnn', 'gnn_direction_option', 'gnn_num_layers', 'num_hidden', 'graph_pooling', 'max_pool_linear_proj', 'gnn_dropout', 'gat_attn_dropout', 'gat_negative_slope', 'gat_num_heads', 'gat_num_out_heads', 'gat_residual', 'graphsage_aggreagte_type', 'seed', 'batch_size', 'epochs', 'patience', 'lr', 'lr_patience', 'lr_reduce_factor', 'num_workers', 'gpu', 'no_cuda'])