Skip to content

Commit

Permalink
keep track of entities of interest to avoid hashing them
Browse files Browse the repository at this point in the history
  • Loading branch information
GillesVandewiele committed Mar 24, 2022
1 parent efaffa5 commit 67e5f2a
Show file tree
Hide file tree
Showing 6 changed files with 105 additions and 63 deletions.
84 changes: 41 additions & 43 deletions pyrdf2vec/walkers/community.py
Expand Up @@ -39,7 +39,7 @@ class CommunityWalker(Walker):
through probabilities and relations that are not explicitly modeled in a
Knowledge Graph. Similar to the Random walking strategy, the Depth First
Search (DFS) algorithm is used if a maximum number of walks is specified.
Otherwise, the Breath First Search (BFS) algorithm is chosen.
Otherwise, the Breadth First Search (BFS) algorithm is chosen.
Attributes:
_is_support_remote: True if the walking strategy can be used with a
Expand Down Expand Up @@ -155,6 +155,7 @@ def _bfs(
"""
walks: Set[Walk] = {(entity,)}
rng = np.random.RandomState(self.random_state)
for i in range(self.max_depth):
for walk in walks.copy():
if is_reverse:
Expand All @@ -163,45 +164,25 @@ def _bfs(
walks.add((obj, pred) + walk)
if (
obj in self.communities
and np.random.RandomState(
self.random_state
).random()
< self.hop_prob
and rng.random() < self.hop_prob
):
comm = self.communities[obj]
comm_labels = self.labels_per_community[comm]
walks.add(
(
np.random.RandomState(
self.random_state
).choice(
self.labels_per_community[
self.communities[obj]
]
),
)
+ walk
(rng.random().choice(comm_labels),) + walk
)
else:
hops = kg.get_hops(walk[-1])
for pred, obj in hops:
walks.add(walk + (pred, obj))
if (
obj in self.communities
and np.random.RandomState(
self.random_state
).random()
< self.hop_prob
and rng.random() < self.hop_prob
):
comm = self.communities[obj]
comm_labels = self.labels_per_community[comm]
walks.add(
walk
+ (
np.random.RandomState(
self.random_state
).choice(
self.labels_per_community[
self.communities[obj]
]
),
)
walk + (rng.random().choice(comm_labels),)
)
if len(hops) > 0:
walks.remove(walk)
Expand All @@ -227,6 +208,9 @@ def _dfs(
self.sampler.visited = set()
walks: List[Walk] = []
assert self.max_walks is not None

rng = np.random.RandomState(self.random_state)

while len(walks) < self.max_walks:
sub_walk: Walk = (entity,)
d = 1
Expand All @@ -240,35 +224,29 @@ def _dfs(
if is_reverse:
if (
pred_obj[0] in self.communities
and np.random.RandomState(self.random_state).random()
< self.hop_prob
and rng.random() < self.hop_prob
):
community_nodes = self.labels_per_community[
self.communities[pred_obj[0]]
]
sub_walk = (
pred_obj[1],
np.random.RandomState(self.random_state).choice(
community_nodes
),
rng.choice(community_nodes),
community_nodes,
) + sub_walk
else:
sub_walk = (pred_obj[1], pred_obj[0]) + sub_walk
else:
if (
pred_obj[1] in self.communities
and np.random.RandomState(self.random_state).random()
< self.hop_prob
and rng.random() < self.hop_prob
):
community_nodes = self.labels_per_community[
self.communities[pred_obj[1]]
]
sub_walk += (
pred_obj[0],
np.random.RandomState(self.random_state).choice(
community_nodes
),
rng.choice(community_nodes),
community_nodes,
)
else:
Expand Down Expand Up @@ -327,6 +305,29 @@ def extract_walks(self, kg: KG, entity: Vertex) -> List[Walk]:
]
return [walk for walk in fct_search(kg, entity)]

def _map_vertex(self, entity: Vertex, pos: int) -> str:
"""Maps certain vertices to MD5 hashes to save memory. For entities of
interest (provided by the user to the extract function) and predicates,
the string representation is kept.
Args:
entity: The entity to be mapped.
pos: The position of the entity in the walk.
Returns:
A hash (string) or original string representation.
"""
if (
entity.name in self._entities
or pos % 2 == 1
or self.md5_bytes is None
):
return entity.name
else:
ent_hash = md5(entity.name.encode()).digest()
return str(ent_hash[: self.md5_bytes])

def _extract(self, kg: KG, entity: Vertex) -> EntityWalks:
"""Extracts random walks for an entity based on a Knowledge Graph.
Expand All @@ -342,10 +343,7 @@ def _extract(self, kg: KG, entity: Vertex) -> EntityWalks:
canonical_walks: Set[SWalk] = set()
for walk in self.extract_walks(kg, entity):
canonical_walk: List[str] = [
vertex.name
if i == 0 or i % 2 == 1 or self.md5_bytes is None
else str(md5(vertex.name.encode()).digest()[: self.md5_bytes])
for i, vertex in enumerate(walk)
self._map_vertex(vertex, i) for i, vertex in enumerate(walk)
]
canonical_walks.add(tuple(canonical_walk))
return {entity.name: list(canonical_walks)}
36 changes: 28 additions & 8 deletions pyrdf2vec/walkers/random.py
Expand Up @@ -10,9 +10,9 @@

@attr.s
class RandomWalker(Walker):
"""Random walking strategy which extracts walks from a rood node using the
"""Random walking strategy which extracts walks from a root node using the
Depth First Search (DFS) algorithm if a maximum number of walks is
specified, otherwise the Breath First Search (BFS) algorithm is used.
specified, otherwise the Breadth First Search (BFS) algorithm is used.
Attributes:
_is_support_remote: True if the walking strategy can be used with a
Expand Down Expand Up @@ -51,7 +51,7 @@ def _bfs(
self, kg: KG, entity: Vertex, is_reverse: bool = False
) -> List[Walk]:
"""Extracts random walks for an entity based on Knowledge Graph using
the Breath First Search (BFS) algorithm.
the Breadth First Search (BFS) algorithm.
Args:
kg: The Knowledge Graph.
Expand Down Expand Up @@ -120,7 +120,7 @@ def _dfs(
def extract_walks(self, kg: KG, entity: Vertex) -> List[Walk]:
"""Extracts random walks for an entity based on Knowledge Graph using
the Depth First Search (DFS) algorithm if a maximum number of walks is
specified, otherwise the Breath First Search (BFS) algorithm is used.
specified, otherwise the Breadth First Search (BFS) algorithm is used.
Args:
kg: The Knowledge Graph.
Expand All @@ -139,6 +139,29 @@ def extract_walks(self, kg: KG, entity: Vertex) -> List[Walk]:
]
return [walk for walk in fct_search(kg, entity)]

def _map_vertex(self, entity: Vertex, pos: int) -> str:
"""Maps certain vertices to MD5 hashes to save memory. For entities of
interest (provided by the user to the extract function) and predicates,
the string representation is kept.
Args:
entity: The entity to be mapped.
pos: The position of the entity in the walk.
Returns:
A hash (string) or original string representation.
"""
if (
entity.name in self._entities
or pos % 2 == 1
or self.md5_bytes is None
):
return entity.name
else:
ent_hash = md5(entity.name.encode()).digest()
return str(ent_hash[: self.md5_bytes])

def _extract(self, kg: KG, entity: Vertex) -> EntityWalks:
"""Extracts random walks for an entity based on a Knowledge Graph.
Expand All @@ -154,10 +177,7 @@ def _extract(self, kg: KG, entity: Vertex) -> EntityWalks:
canonical_walks: Set[SWalk] = set()
for walk in self.extract_walks(kg, entity):
canonical_walk: List[str] = [
vertex.name
if i == 0 or i % 2 == 1 or self.md5_bytes is None
else str(md5(vertex.name.encode()).digest()[: self.md5_bytes])
for i, vertex in enumerate(walk)
self._map_vertex(vertex, i) for i, vertex in enumerate(walk)
]
canonical_walks.add(tuple(canonical_walk))
return {entity.name: list(canonical_walks)}
8 changes: 6 additions & 2 deletions pyrdf2vec/walkers/walker.py
@@ -1,7 +1,7 @@
import multiprocessing
import warnings
from abc import ABC, abstractmethod
from typing import List, Optional
from typing import List, Optional, Set

import attr
from tqdm import tqdm
Expand Down Expand Up @@ -101,6 +101,8 @@ class Walker(ABC):
init=False, repr=False, type=bool, default=True
)

_entities = attr.ib(init=False, repr=False, type=Set[str], default=set())

def __attrs_post_init__(self):
if self.n_jobs == -1:
self.n_jobs = multiprocessing.cpu_count()
Expand Down Expand Up @@ -136,7 +138,6 @@ def extract(
"Invalid walking strategy. Please, choose a walking strategy "
+ "that can fetch walks via a SPARQL endpoint server."
)
self.sampler.fit(kg)

process = self.n_jobs if self.n_jobs is not None else 1
if (kg._is_remote and kg.mul_req) and process >= 2:
Expand All @@ -151,6 +152,9 @@ def extract(
if kg._is_remote and kg.mul_req:
kg._fill_hops(entities)

self.sampler.fit(kg)
self._entities |= set(entities)

with multiprocessing.Pool(process, self._init_worker, [kg]) as pool:
res = list(
tqdm(
Expand Down
24 changes: 20 additions & 4 deletions pyrdf2vec/walkers/weisfeiler_lehman.py
Expand Up @@ -152,6 +152,25 @@ def extract(
self._weisfeiler_lehman(kg)
return super().extract(kg, entities, verbose)

def _map_wl(self, entity: Vertex, pos: int, n: int) -> str:
"""Maps certain vertices to MD5 hashes to save memory. For entities of
interest (provided by the user to the extract function) and predicates,
the string representation is kept.
Args:
entity: The entity to be mapped.
pos: The position of the entity in the walk.
n: The iteration number of the WL algorithm.
Returns:
A hash (string) or original string representation.
"""
if entity.name in self._entities or pos % 2 == 1:
return entity.name
else:
return self._label_map[entity][n]

def _extract(self, kg: KG, entity: Vertex) -> EntityWalks:
"""Extracts random walks for an entity based on a Knowledge Graph.
Expand All @@ -168,10 +187,7 @@ def _extract(self, kg: KG, entity: Vertex) -> EntityWalks:
for n in range(self.wl_iterations + 1):
for walk in self.extract_walks(kg, entity):
canonical_walk: List[str] = [
vertex.name
if i == 0 or i % 2 == 1
else self._label_map[vertex][n]
for i, vertex in enumerate(walk)
self._map_wl(vertex, i, n) for i, vertex in enumerate(walk)
]
canonical_walks.add(tuple(canonical_walk))
return {entity.name: list(canonical_walks)}
8 changes: 5 additions & 3 deletions tests/walkers/test_halk.py
Expand Up @@ -60,13 +60,14 @@ def test_extract(
self, setup, kg, root, max_depth, max_walks, with_reverse
):
root = f"{URL}#{root}"
walks = HALKWalker(
walker = HALKWalker(
max_depth,
max_walks,
freq_thresholds=[0.001],
with_reverse=with_reverse,
random_state=42,
).extract(kg, [root])
)
walks = walker.extract(kg, [root])

if max_walks is not None:
assert len(walks) == 1
Expand All @@ -76,4 +77,5 @@ def test_extract(
if not with_reverse:
assert walk[0] == root
for obj in walk[2::2]:
assert obj.startswith("b'")
if obj not in walker._entities:
assert obj.startswith("b'")
8 changes: 5 additions & 3 deletions tests/walkers/test_random.py
Expand Up @@ -98,17 +98,19 @@ def test_extract(
self, setup, kg, root, max_depth, max_walks, with_reverse
):
root = f"{URL}#{root}"
walks = RandomWalker(
walker = RandomWalker(
max_depth, max_walks, with_reverse=with_reverse, random_state=42
)._extract(kg, Vertex(root))[root]
)
walks = walker._extract(kg, Vertex(root))[root]
if max_walks is not None:
if with_reverse:
assert len(walks) <= max_walks * max_walks
else:
assert len(walks) <= max_walks
for walk in walks:
for obj in walk[2::2]:
assert obj.startswith("b'")
if obj not in walker._entities:
assert obj.startswith("b'")
if not with_reverse:
assert walk[0] == root
assert len(walk) <= (max_depth * 2) + 1
Expand Down

0 comments on commit 67e5f2a

Please sign in to comment.