Skip to content

Commit

Permalink
adds edge2vec conversion script
Browse files Browse the repository at this point in the history
  • Loading branch information
JSybrandt committed May 22, 2020
1 parent 4f700fe commit 24605be
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 3 deletions.
6 changes: 3 additions & 3 deletions scripts/train_2020.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@

DISTRIBUTED=${1:-"0"}

#RELEASE_DIR="/zfs/safrolab/users/jsybran/agatha/data/releases/2020"
RELEASE_DIR="/burstbuffer/fast/covid/2020_release"
MODEL_DIR="$RELEASE_DIR/hypothesis_predictor"
RELEASE_DIR="/zfs/safrolab/users/jsybran/agatha/data/releases/2020"
#RELEASE_DIR="/burstbuffer/fast/covid/2020_release"
MODEL_DIR="$RELEASE_DIR/hypothesis_predictor_test"
mkdir -p $MODEL_DIR

NUM_NODES=1
Expand Down
75 changes: 75 additions & 0 deletions tools/py_scripts/edge2vec_conversion/semmeddb_to_edge2vec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
#!/usr/bin/env python3
"""
Converts SemMedDB CSV for input into Edge2Vec
Edge2Vec input format is <source id> <target id> <relation id> <edge id>
For example
```
1 4 1 1
1 5 1 2
1 6 1 3
1 7 1 4
2 7 1 5
2 8 1 6
2 9 1 7
2 10 1 8
```
"""

from fire import Fire
from pathlib import Path
from agatha.util import semmeddb_util as sm
from agatha.util import entity_types as typs
import pickle
from itertools import product
from typing import Dict

def get_or_add_idx(name:str, name2idx:Dict[str, int])->int:
# Either int or None
idx = name2idx.get(name)
if idx is None:
# ids should start at 1
idx = name2idx[name] = len(name2idx) + 1
return idx

def main(
semmeddb_csv_path:Path,
output_edge_list:Path,
output_index:Path,
cut_date:sm.Datestr=None,
):
semmeddb_csv_path = Path(semmeddb_csv_path)
output_edge_list = Path(output_edge_list)
output_index = Path(output_index)
assert semmeddb_csv_path.is_file()
assert not output_edge_list.exists()
assert not output_index.exists()

index = dict(
node2idx={},
relation2idx={},
)
get_or_add_node = lambda n: get_or_add_idx(n, index["node2idx"])
get_or_add_relation = lambda n: get_or_add_idx(n, index["relation2idx"])

predicates = sm.parse(semmeddb_csv_path)
if cut_date is not None:
predicates = sm.filter_by_date(predicates, cut_date)

edge_idx = 1
with open(output_edge_list, 'w') as out_edge_file:
for predicate in predicates:
# Its actually only one ID
sub = get_or_add_node(predicate["subj_ids"])
obj = get_or_add_node(predicate["obj_ids"])
vrb = get_or_add_relation(predicate["pred_type"])
out_edge_file.write(f"{sub} {obj} {vrb} {edge_idx}\n")
edge_idx += 1
with open(output_index, 'wb') as out_index_file:
pickle.dump(index, out_index_file)


if __name__=="__main__":
Fire(main)

0 comments on commit 24605be

Please sign in to comment.