
moving clean unseen func to datasets.py, excluding unseen entities in wn18rr and fb15k-237 right after loaded
chanlevan committed Mar 22, 2019
1 parent 914d6b9 commit ec151c2
Showing 7 changed files with 91 additions and 128 deletions.
78 changes: 75 additions & 3 deletions ampligraph/datasets/datasets.py
@@ -19,6 +19,78 @@
logger.setLevel(logging.DEBUG)



def _clean_data(X, throw_valid=False):
    """Remove triples containing entities that never appear in the training set.

    If ``throw_valid`` is False, entities seen in either the training or the
    validation set are treated as known, and only the test set is filtered.
    If ``throw_valid`` is True, both the validation and the test sets are
    filtered against the training entities alone.
    """
    train = X["train"]
    valid = X["valid"]
    test = X["test"]

    train_ent = set(train.flatten())
    valid_ent = set(valid.flatten())
    test_ent = set(test.flatten())

    # keep the unseen entities in the validation set:
    # filter only the test set, against train + valid entities
    if not throw_valid:
        ent_test_diff_train_valid = test_ent - (train_ent | valid_ent)
        # indices of test triples mentioning at least one unseen entity
        idxs_test = [i for i, row in enumerate(test)
                     if set(row) & ent_test_diff_train_valid]
        filtered_test = np.delete(test, idxs_test, axis=0)
        logger.debug("fit validation case: shape test: {0} "
                     "- filtered test: {1}: {2} triples "
                     "with unseen entities removed"
                     .format(test.shape, filtered_test.shape, len(idxs_test)))
        return {'train': train, 'valid': valid, 'test': filtered_test}

    # throw the unseen entities out of the validation set as well:
    # filter valid and test against the training entities only
    else:
        ent_valid_diff_train = valid_ent - train_ent
        idxs_valid = [i for i, row in enumerate(valid)
                      if set(row) & ent_valid_diff_train]
        filtered_valid = np.delete(valid, idxs_valid, axis=0)
        logger.debug("not fitting validation case: shape valid: {0} "
                     "- filtered valid: {1}: {2} triples "
                     "with unseen entities removed"
                     .format(valid.shape, filtered_valid.shape, len(idxs_valid)))

        ent_test_diff_train = test_ent - train_ent
        idxs_test = [i for i, row in enumerate(test)
                     if set(row) & ent_test_diff_train]
        filtered_test = np.delete(test, idxs_test, axis=0)
        logger.debug("not fitting validation case: shape test: {0} "
                     "- filtered test: {1}: {2} triples "
                     "with unseen entities removed"
                     .format(test.shape, filtered_test.shape, len(idxs_test)))

        return {'train': train, 'valid': filtered_valid, 'test': filtered_test}
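
A minimal sketch of the filtering on toy triples (the arrays and entity/relation names below are illustrative, not taken from the shipped datasets):

    import numpy as np

    X = {'train': np.array([['a', 'r1', 'b'], ['b', 'r1', 'c']]),
         'valid': np.array([['a', 'r1', 'c'], ['d', 'r1', 'a']]),  # 'd' never seen in train
         'test': np.array([['b', 'r1', 'a'], ['e', 'r1', 'c']])}   # 'e' never seen in train

    clean = _clean_data(X, throw_valid=True)
    # both offending triples are dropped, so each split shrinks from 2 rows to 1
    assert clean['valid'].shape == (1, 3)
    assert clean['test'].shape == (1, 3)

With throw_valid=False the validation set is kept intact and the test set is filtered against the union of training and validation entities, so only the 'e' triple would be removed.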


def _get_data_home(data_home=None):
"""Get to location of the dataset folder to use.
@@ -252,7 +324,7 @@ def load_wn18():
return _load_core_dataset('WN18', data_home=None)


def load_wn18rr():
def load_wn18rr(clean_unseen=True):
""" Load the WN18RR dataset
The dataset is described in :cite:`DettmersMS018`.
@@ -295,7 +367,7 @@ def load_wn18rr():
"""

return _load_core_dataset('WN18RR', data_home=None)
    X = _load_core_dataset('WN18RR', data_home=None)
    return _clean_data(X, throw_valid=True) if clean_unseen else X
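
As a usage sketch (assuming load_wn18rr is re-exported by ampligraph.datasets, as the tests below suggest): with the default clean_unseen=True, every element of a surviving valid/test triple already appears in the training split, because _clean_data flattens whole triples when building its entity sets.

    from ampligraph.datasets import load_wn18rr

    X = load_wn18rr()  # clean_unseen defaults to True
    train_elems = set(X['train'].flatten())
    # valid and test were filtered against the training split alone
    assert set(X['valid'].flatten()) <= train_elems
    assert set(X['test'].flatten()) <= train_elems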


def load_fb15k():
@@ -389,7 +461,7 @@ def load_fb15k_237():
dtype=object)
"""

return _load_core_dataset('FB15K_237', data_home=None)
return _clean_data(_load_core_dataset('FB15K_237', data_home=None), throw_valid=True)


def load_yago3_10():
4 changes: 2 additions & 2 deletions ampligraph/evaluation/__init__.py
@@ -4,8 +4,8 @@
from .metrics import mrr_score, mr_score, hits_at_n_score, rank_score
from .protocol import generate_corruptions_for_fit, evaluate_performance, to_idx, \
generate_corruptions_for_eval, create_mappings, select_best_model_ranking, train_test_split_no_unseen, \
filter_unseen_entities, clean_data
filter_unseen_entities

__all__ = ['mrr_score', 'hits_at_n_score', 'rank_score', 'generate_corruptions_for_fit',
'evaluate_performance', 'to_idx', 'generate_corruptions_for_eval', 'create_mappings',
'select_best_model_ranking', 'train_test_split_no_unseen', 'filter_unseen_entities', 'clean_data']
'select_best_model_ranking', 'train_test_split_no_unseen', 'filter_unseen_entities']
69 changes: 0 additions & 69 deletions ampligraph/evaluation/protocol.py
@@ -9,75 +9,6 @@
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)


def clean_data(train, valid, test, throw_valid=False):
"""clean datasets with unseen entities
"""
train_ent = set(train.flatten())
valid_ent = set(valid.flatten())
test_ent = set(test.flatten())

# not throwing the unseen entities in validation set
if not throw_valid:
train_valid_ent = set(train.flatten()) | set(valid.flatten())
ent_test_diff_train_valid = test_ent - train_valid_ent
idxs_test = []

if len(ent_test_diff_train_valid) > 0:
count_test = 0
c_if = 0
for row in test:
tmp = set(row)
if len(tmp & ent_test_diff_train_valid) != 0:
idxs_test.append(count_test)
c_if += 1
count_test = count_test + 1
filtered_test = np.delete(test, idxs_test, axis=0)
logging.debug("fit validation case: shape test: {0} \
- filtered test: {1}: {2} triples \
with unseen entities removed" \
.format(test.shape, filtered_test.shape, c_if))
return valid, filtered_test

# throwing the unseen entities in validation set
else:
# for valid
ent_valid_diff_train = valid_ent - train_ent
idxs_valid = []
if len(ent_valid_diff_train) > 0:
count_valid = 0
c_if = 0
for row in valid:
tmp = set(row)
if len(tmp & ent_valid_diff_train) != 0:
idxs_valid.append(count_valid)
c_if += 1
count_valid = count_valid + 1
filtered_valid = np.delete(valid, idxs_valid, axis=0)
logging.debug("not fitting validation case: shape valid: {0} \
- filtered valid: {1}: {2} triples \
with unseen entities removed" \
.format(valid.shape, filtered_valid.shape, c_if))
# for test
ent_test_diff_train = test_ent - train_ent
idxs_test = []
if len(ent_test_diff_train) > 0:
count_test = 0
c_if = 0
for row in test:
tmp = set(row)
if len(tmp & ent_test_diff_train) != 0:
idxs_test.append(count_test)
c_if += 1
count_test = count_test + 1
filtered_test = np.delete(test, idxs_test, axis=0)
logging.debug("not fitting validation case: shape test: {0} \
- filtered test: {1}: {2} triples \
with unseen entities removed" \
.format(test.shape, filtered_test.shape, c_if))
return filtered_valid, filtered_test


def train_test_split_no_unseen(X, test_size=5000, seed=0):
"""Split into train and test sets.
1 change: 0 additions & 1 deletion experiments/config.json
@@ -1,6 +1,5 @@
{
"CUDA_VISIBLE_DEVICES": "0",
"DATASET_WITH_UNSEEN_ENTITIES": ["WN18RR", "FB15K-237"],
"load_function_map": {
"WN18": "load_wn18",
"FB15K": "load_fb15k",
11 changes: 1 addition & 10 deletions experiments/predictive_performance.py
@@ -1,6 +1,6 @@
import ampligraph.datasets
import ampligraph.latent_features
from ampligraph.evaluation import hits_at_n_score, mr_score, evaluate_performance, mrr_score, clean_data
from ampligraph.evaluation import hits_at_n_score, mr_score, evaluate_performance, mrr_score

import argparse
import os
@@ -82,15 +82,6 @@ def run_single_exp(config, dataset, model):
X = load_func()
# logging.debug("Loaded...{0}...".format(dataset))

if dataset in config["DATASET_WITH_UNSEEN_ENTITIES"]:
logging.debug("{0} contains unseen entities \
in test dataset, we cleaned them..." \
.format(dataset))
X["valid"], X["test"] = clean_data(X["train"],
X["valid"],
X["test"],
throw_valid=True)

# load model
model_class = getattr(ampligraph.latent_features,
config["model_name_map"][model])
7 changes: 1 addition & 6 deletions setup.py
@@ -26,11 +26,6 @@
'beautifultable>=0.7.0',
'pyyaml>=3.13',
'rdflib>=4.2.2'
],
extras_require={
'cpu': ['tensorflow>=1.12.0,<2.0'],
'gpu': ['tensorflow-gpu>=1.12.0,<2.0'],
}
)
])
if __name__ == '__main__':
setup(**setup_params)
49 changes: 12 additions & 37 deletions tests/ampligraph/datasets/test_datasets.py
@@ -38,31 +38,15 @@ def test_load_fb15k():

def test_load_fb15k_237():
fb15k_237 = load_fb15k_237()
assert len(fb15k_237['train']) == 272115
assert len(fb15k_237['valid']) == 17535
assert len(fb15k_237['test']) == 20466
assert len(fb15k_237['train']) == 272115

# - 9 because 9 triples containing unseen entities are removed
assert len(fb15k_237['valid']) == 17535 - 9

# ent_train = np.union1d(np.unique(fb15k_237["train"][:,0]), np.unique(fb15k_237["train"][:,2]))
# ent_valid = np.union1d(np.unique(fb15k_237["valid"][:,0]), np.unique(fb15k_237["valid"][:,2]))
# ent_test = np.union1d(np.unique(fb15k_237["test"][:,0]), np.unique(fb15k_237["test"][:,2]))
# distinct_ent = np.union1d(np.union1d(ent_train, ent_valid), ent_test)
# distinct_rel = np.union1d(np.union1d(np.unique(fb15k_237["train"][:,1]), np.unique(fb15k_237["train"][:,1])), np.unique(fb15k_237["train"][:,1]))

# assert len(distinct_ent) == 14541
# assert len(distinct_rel) == 237


# train_all_ent = set(fb15k_237['train'].flatten())
# valid_all_ent = set(fb15k_237['valid'].flatten())
# test_all_ent = set(fb15k_237['test'].flatten())
# - 28 because 28 triples containing unseen entities are removed
assert len(fb15k_237['test']) == 20466 - 28

# unseen_valid = valid_all_ent - train_all_ent
# train_valid_ent = (valid_all_ent - unseen_valid) | train_all_ent

# unseen_test = test_all_ent - train_valid_ent

# assert len(unseen_valid) == 8
# assert len(unseen_test) == 29



def test_yago_3_10():
@@ -92,19 +76,10 @@ def test_wn18rr():
np.unique(wn18rr["train"][:, 1]))

assert len(wn18rr['train']) == 86835
assert len(wn18rr['valid']) == 3034
assert len(wn18rr['test']) == 3134
# assert len(distinct_ent) == 40943
# assert len(distinct_rel) == 11

# train_all_ent = set(wn18rr['train'].flatten())
# valid_all_ent = set(wn18rr['valid'].flatten())
# test_all_ent = set(wn18rr['test'].flatten())

# unseen_valid = valid_all_ent - train_all_ent
# train_valid_ent = (valid_all_ent - unseen_valid) | train_all_ent

# unseen_test = test_all_ent - train_valid_ent
# - 210 because 210 triples containing unseen entities are removed
assert len(wn18rr['valid']) == 3034 - 210

# assert len(unseen_valid) == 198
# assert len(unseen_test) == 209
# - 210 because 210 triples containing unseen entities are removed
assert len(wn18rr['test']) == 3134 - 210
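
The hard-coded offsets in these tests (9 and 28 for FB15K-237, 210 for WN18RR) can be re-derived from the raw splits; a sketch, assuming train and valid_raw hold the unfiltered arrays:

    seen = set(train.flatten())
    unseen = set(valid_raw.flatten()) - seen
    # triples mentioning at least one element never seen in training
    n_dropped = sum(1 for row in valid_raw if set(row) & unseen)

Running the same count over each raw test split (against the training elements alone, since both loaders pass throw_valid=True) should yield 28 for FB15K-237 and 210 for WN18RR.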
