Skip to content

Commit 9a60ce4

Browse files
committed
--swap_source_target and --drop_invalid args for convenience
1 parent 0215d58 commit 9a60ce4

File tree

3 files changed

+52
-4
lines changed

3 files changed

+52
-4
lines changed

gp_learner.py

+8 -2
Original file line number | Diff line number | Diff line change
@@ -1682,6 +1682,8 @@ def main(
16821682
splitting_variant='random',
16831683
train_filename=None,
16841684
test_filename=None,
1685+
swap_source_target=False,
1686+
drop_invalid=False,
16851687
init_patterns_filename=None,
16861688
print_train_test_sets=True,
16871689
reset=False,
@@ -1708,14 +1710,18 @@ def main(
17081710
timer_start = datetime.utcnow()
17091711
main_start = timer_start
17101712

1713+
gsa = partial(
1714+
get_semantic_associations,
1715+
swap_source_target=swap_source_target,
1716+
drop_invalid=drop_invalid,
1717+
)
17111718
if not train_filename and not test_filename:
17121719
# get semantic association pairs and split in train and test sets
1713-
semantic_associations = get_semantic_associations(associations_filename)
1720+
semantic_associations = gsa(associations_filename)
17141721
assocs_train, assocs_test = split_training_test_set(
17151722
semantic_associations, variant=splitting_variant
17161723
)
17171724
else:
1718-
gsa = get_semantic_associations
17191725
assocs_train = gsa(train_filename) if train_filename else []
17201726
assocs_test = gsa(test_filename) if test_filename else []
17211727
if predict == 'train_set':

ground_truth_tools.py

+28 -2
Original file line number | Diff line number | Diff line change
@@ -81,7 +81,9 @@ def URIRefify(links):
8181
return tuple([URIRef(l) for l in links])
8282

8383

84-
def get_semantic_associations(fn=None, limit=None):
84+
def get_semantic_associations(
85+
fn=None, limit=None, swap_source_target=False, drop_invalid=False
86+
):
8587
if not fn:
8688
verified_mappings = get_verified_mappings()
8789
semantic_associations = get_dbpedia_pairs_from_mappings(
@@ -105,7 +107,31 @@ def get_semantic_associations(fn=None, limit=None):
105107
break
106108
source = from_n3(row['source'].decode('UTF-8'))
107109
target = from_n3(row['target'].decode('UTF-8'))
108-
semantic_associations.append((source, target))
110+
111+
for x in (source, target):
112+
# noinspection PyBroadException
113+
try:
114+
x.n3()
115+
except Exception:
116+
if drop_invalid:
117+
logger.warning(
118+
'ignoring ground truth pair %r: %r cannot be '
119+
'serialized as N3',
120+
(row['source'], row['target']), x
121+
)
122+
break
123+
else:
124+
logger.error(
125+
'error in ground truth pair %r: %r cannot be '
126+
'serialized as N3',
127+
(row['source'], row['target']), x
128+
)
129+
raise
130+
else:
131+
semantic_associations.append((source, target))
132+
if swap_source_target:
133+
logger.info('swapping all (source, target) pairs: (s,t) --> (t,s)')
134+
semantic_associations = [(t, s) for s, t in semantic_associations]
109135
return semantic_associations
110136

111137

run.py

+16
Original file line number | Diff line number | Diff line change
@@ -64,6 +64,22 @@
6464
default=None,
6565
)
6666

67+
parser.add_argument(
68+
"--swap_source_target",
69+
help="allows to turn the ground truth source-target-pairs around for "
70+
"all following considerations: (s,t) --> (t,s)",
71+
action="store_true",
72+
default=False,
73+
)
74+
parser.add_argument(
75+
"--drop_invalid",
76+
help="drops invalid ground truth source-target-pairs (i.e., invalid N3 "
77+
"pairs (e.g., due to bad (URI) encoding)). Will still warn about "
78+
"them.",
79+
action="store_true",
80+
default=False,
81+
)
82+
6783
parser.add_argument(
6884
"--print_train_test_sets",
6985
help="prints the sets used for training and testing",

0 commit comments

Comments
 (0)