Skip to content

Commit

Permalink
Updates edge2vec conversion tools for training
Browse files: browse the repository at this point in the history
  • Loading branch information
JSybrandt committed May 23, 2020
1 parent c1c7b1f commit 22936f9
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 10 deletions.
28 changes: 18 additions & 10 deletions tools/py_scripts/edge2vec_conversion/index_edge2vec_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,16 @@
from pathlib import Path
from fire import Fire
import pickle
from typing import Iterable, List, Tuple
from typing import Iterable, List, Tuple, Dict
from tqdm import tqdm
from agatha.ml.hypothesis_predictor.predicate_util import clean_coded_term


def iterate_vectors(vector_text_path:Path)->Iterable[Tuple[str, List[float]]]:
def iterate_vectors(
vector_text_path:Path,
idx2node:Dict[str, int],
predicate_keys:Iterable[str],
)->Iterable[Tuple[str, List[float]]]:
assert vector_text_path.is_file()
num_vectors = None
expected_dim = None
Expand All @@ -37,7 +41,10 @@ def iterate_vectors(vector_text_path:Path)->Iterable[Tuple[str, List[float]]]:
if expected_dim is not None and len(tokens) == expected_dim + 1:
idx = int(tokens[0])
vec = [float(t) for t in tokens[1:]]
yield idx, vec
yield clean_coded_term(idx2node[idx]), vec
assert expected_dim is not None
for pred_key in predicate_keys:
yield pred_key, [0]*expected_dim


def main(
Expand All @@ -57,15 +64,16 @@ def main(
assert len(list(output_dir.iterdir())) == 0

with open(input_index_path, 'rb') as pkl_file:
node2idx = pickle.load(pkl_file)["node2idx"]
index = pickle.load(pkl_file)

idx2node = {i: n for n, i in node2idx.items()}
idx2node = {i: n for n, i in index["node2idx"].items()}

node2vec = dict(tqdm(iterate_vectors(
input_vector_text_path,
idx2node=idx2node,
predicate_keys=index["predicate_keys"]
)))

node2vec = {
clean_coded_term(idx2node[idx]): vec
for idx, vec
in tqdm(iterate_vectors(input_vector_text_path))
}
setup_embedding_lookup_data(
node2vec,
test_name="edge2vec",
Expand Down
2 changes: 2 additions & 0 deletions tools/py_scripts/edge2vec_conversion/semmeddb_to_edge2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def main(
index = dict(
node2idx={},
relation2idx={},
predicate_keys=set()
)
get_or_add_node = lambda n: get_or_add_idx(n, index["node2idx"])
get_or_add_relation = lambda n: get_or_add_idx(n, index["relation2idx"])
Expand All @@ -67,6 +68,7 @@ def main(
vrb = get_or_add_relation(predicate["pred_type"])
out_edge_file.write(f"{sub} {obj} {vrb} {edge_idx}\n")
edge_idx += 1
index["predicate_keys"].add(sm.predicate_to_key(predicate))
with open(output_index, 'wb') as out_index_file:
pickle.dump(index, out_index_file)

Expand Down

0 comments on commit 22936f9

Please sign in to comment.