Skip to content

Commit

Permalink
Fix sql edge count logic
Browse files Browse the repository at this point in the history
  • Loading branch information
j6k4m8 committed Apr 17, 2024
1 parent c688855 commit 7220b98
Show file tree
Hide file tree
Showing 5 changed files with 113 additions and 64 deletions.
14 changes: 6 additions & 8 deletions grand/backends/_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,14 +289,12 @@ def get_edge_by_id(self, u: Hashable, v: Hashable):
"""
if self._directed:
return (
self._edge_df[
(self._edge_df[self._edge_df_source_column] == u)
& (self._edge_df[self._edge_df_target_column] == v)
]
.iloc[0]
.to_dict()
)
result = self._edge_df[
(self._edge_df[self._edge_df_source_column] == u)
& (self._edge_df[self._edge_df_target_column] == v)
]
if len(result):
return self._edge_as_dict(result.iloc[0])

else:
left = self._edge_df[
Expand Down
131 changes: 77 additions & 54 deletions grand/backends/_sqlbackend.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,11 @@ def __init__(
self._node_table.create(self._engine, checkfirst=True)

source_column = sqlalchemy.Column(
self._edge_source_key, sqlalchemy.String(_DEFAULT_SQL_STR_LEN)
self._edge_source_key, sqlalchemy.String(_DEFAULT_SQL_STR_LEN)
)

target_column = sqlalchemy.Column(
self._edge_target_key, sqlalchemy.String(_DEFAULT_SQL_STR_LEN)
self._edge_target_key, sqlalchemy.String(_DEFAULT_SQL_STR_LEN)
)

# Create edges table
Expand All @@ -92,7 +92,7 @@ def __init__(
),
sqlalchemy.Column("_metadata", sqlalchemy.JSON),
source_column,
target_column
target_column,
)
self._edge_table.create(self._engine, checkfirst=True)

Expand Down Expand Up @@ -156,10 +156,13 @@ def add_node(self, node_name: Hashable, metadata: dict) -> Hashable:
return node_name

def add_nodes_from(self, nodes_for_adding, **attr):
nodes = [{
self._primary_key: node,
"_metadata": {**attr, **metadata},
} for node, metadata in nodes_for_adding]
nodes = [
{
self._primary_key: node,
"_metadata": {**attr, **metadata},
}
for node, metadata in nodes_for_adding
]

self._connection.execute(self._node_table.insert(), nodes)

Expand Down Expand Up @@ -204,7 +207,9 @@ def all_nodes_as_iterable(self, include_metadata: bool = False) -> Generator:
if include_metadata:
sql = self._node_table.select()
else:
sql = self._node_table.select().with_only_columns(self._node_table.c[self._primary_key])
sql = self._node_table.select().with_only_columns(
self._node_table.c[self._primary_key]
)

results = []
for x in self._connection.execute(sql):
Expand Down Expand Up @@ -277,12 +282,15 @@ def add_edge(self, u: Hashable, v: Hashable, metadata: dict):
return pk

def add_edges_from(self, ebunch_to_add, **attr):
edges = [{
self._primary_key: f"__{u}__{v}",
self._edge_source_key: u,
self._edge_target_key: v,
"_metadata": {**attr, **metadata},
} for u, v, metadata in ebunch_to_add]
edges = [
{
self._primary_key: f"__{u}__{v}",
self._edge_source_key: u,
self._edge_target_key: v,
"_metadata": {**attr, **metadata},
}
for u, v, metadata in ebunch_to_add
]

self._connection.execute(self._edge_table.insert(), edges)

Expand All @@ -299,7 +307,7 @@ def all_edges_as_iterable(self, include_metadata: bool = False) -> Generator:

columns = [
self._node_table.c[self._edge_source_key],
self._node_table.c[self._edge_target_key]
self._node_table.c[self._edge_target_key],
]

if include_metadata:
Expand Down Expand Up @@ -345,28 +353,26 @@ def get_edge_by_id(self, u: Hashable, v: Hashable):
"""
if self._directed:
pk = f"__{u}__{v}"
return (
self._connection.execute(
self._edge_table.select().where(
self._edge_table.c[self._primary_key] == pk
)
result = self._connection.execute(
self._edge_table.select().where(
self._edge_table.c[self._primary_key] == pk
)
.fetchone()
._metadata
)
).fetchone()
if result:
return result._metadata
raise KeyError(f"Edge {u}-{v} not found.")
else:
return (
self._connection.execute(
self._edge_table.select().where(
or_(
(self._edge_table.c[self._primary_key] == f"__{u}__{v}"),
(self._edge_table.c[self._primary_key] == f"__{v}__{u}"),
)
result = self._connection.execute(
self._edge_table.select().where(
or_(
(self._edge_table.c[self._primary_key] == f"__{u}__{v}"),
(self._edge_table.c[self._primary_key] == f"__{v}__{u}"),
)
)
.fetchone()
._metadata
)
).fetchone()
if result:
return result._metadata
raise KeyError(f"Edge {u}-{v} not found.")

def get_node_neighbors(
self, u: Hashable, include_metadata: bool = False
Expand All @@ -384,18 +390,20 @@ def get_node_neighbors(

if self._directed:
res = self._connection.execute(
self._edge_table.select().where(
self._edge_table.c[self._edge_source_key] == str(u)
).order_by(self._edge_table.c[self._primary_key])
self._edge_table.select()
.where(self._edge_table.c[self._edge_source_key] == str(u))
.order_by(self._edge_table.c[self._primary_key])
).fetchall()
else:
res = self._connection.execute(
self._edge_table.select().where(
self._edge_table.select()
.where(
or_(
(self._edge_table.c[self._edge_source_key] == str(u)),
(self._edge_table.c[self._edge_target_key] == str(u)),
)
).order_by(self._edge_table.c[self._primary_key])
)
.order_by(self._edge_table.c[self._primary_key])
).fetchall()

res = [x._asdict() for x in res]
Expand Down Expand Up @@ -436,18 +444,20 @@ def get_node_predecessors(
"""
if self._directed:
res = self._connection.execute(
self._edge_table.select().where(
self._edge_table.c[self._edge_target_key] == str(u)
).order_by(self._edge_table.c[self._primary_key])
self._edge_table.select()
.where(self._edge_table.c[self._edge_target_key] == str(u))
.order_by(self._edge_table.c[self._primary_key])
).fetchall()
else:
res = self._connection.execute(
self._edge_table.select().where(
self._edge_table.select()
.where(
or_(
(self._edge_table.c[self._edge_target_key] == str(u)),
(self._edge_table.c[self._edge_source_key] == str(u)),
)
).order_by(self._edge_table.c[self._primary_key])
)
.order_by(self._edge_table.c[self._primary_key])
).fetchall()

res = [x._asdict() for x in res]
Expand All @@ -473,7 +483,7 @@ def get_node_predecessors(
]
)

def get_node_count(self) -> Iterable:
def get_node_count(self) -> int:
"""
Get an integer count of the number of nodes in this graph.
Expand All @@ -488,6 +498,21 @@ def get_node_count(self) -> Iterable:
select(func.count()).select_from(self._node_table)
).scalar()

def get_edge_count(self) -> int:
    """
    Count the rows of the edge table and return that number.

    Arguments:
        None

    Returns:
        int: The count of edges

    """
    # Build the COUNT(*) statement first, then run it on the connection
    # and take the single scalar value it yields.
    count_query = select(func.count()).select_from(self._edge_table)
    return self._connection.execute(count_query).scalar()

def out_degrees(self, nbunch=None):
"""
Return the out-degree of each node in the graph.
Expand All @@ -503,7 +528,9 @@ def out_degrees(self, nbunch=None):
if nbunch is None:
where_clause = None
elif isinstance(nbunch, (list, tuple)):
where_clause = self._edge_table.c[self._edge_source_key].in_([str(x) for x in nbunch])
where_clause = self._edge_table.c[self._edge_source_key].in_(
[str(x) for x in nbunch]
)
else:
# single node:
where_clause = self._edge_table.c[self._edge_source_key] == str(nbunch)
Expand All @@ -524,10 +551,7 @@ def out_degrees(self, nbunch=None):
if where_clause is not None:
query = query.where(where_clause)

results = {
r[0]: r[1]
for r in self._connection.execute(query)
}
results = {r[0]: r[1] for r in self._connection.execute(query)}

if nbunch and not isinstance(nbunch, (list, tuple)):
return results.get(nbunch, 0)
Expand All @@ -548,7 +572,9 @@ def in_degrees(self, nbunch=None):
if nbunch is None:
where_clause = None
elif isinstance(nbunch, (list, tuple)):
where_clause = self._edge_table.c[self._edge_target_key].in_([str(x) for x in nbunch])
where_clause = self._edge_table.c[self._edge_target_key].in_(
[str(x) for x in nbunch]
)
else:
# single node:
where_clause = self._edge_table.c[self._edge_target_key] == str(nbunch)
Expand All @@ -569,10 +595,7 @@ def in_degrees(self, nbunch=None):
if where_clause is not None:
query = query.where(where_clause)

results = {
r[0]: r[1]
for r in self._connection.execute(query)
}
results = {r[0]: r[1] for r in self._connection.execute(query)}

if nbunch and not isinstance(nbunch, (list, tuple)):
return results.get(nbunch, 0)
Expand Down
16 changes: 16 additions & 0 deletions grand/backends/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,22 @@ def has_node(self, u: Hashable) -> bool:
except KeyError:
return False

def has_edge(self, u: Hashable, v: Hashable) -> bool:
    """
    Check whether an edge between two nodes is present in the graph.

    Arguments:
        u (Hashable): The source node ID
        v (Hashable): The target node ID

    Returns:
        bool: True if the edge exists, False otherwise

    """
    # A lookup that raises KeyError and a lookup that returns None are both
    # treated as "no such edge".
    try:
        edge = self.get_edge_by_id(u, v)
    except KeyError:
        return False
    return edge is not None

def add_edge(self, u: Hashable, v: Hashable, metadata: dict):
"""
Add a new edge to the graph between two nodes.
Expand Down
8 changes: 8 additions & 0 deletions grand/backends/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
os.environ.get("TEST_NETWORKXBACKEND", default="1") != "1",
reason="NetworkX Backend skipped because $TEST_NETWORKXBACKEND != 1.",
),
id="NetworkXBackend",
),
]
backend_test_params = [
Expand All @@ -52,6 +53,7 @@
os.environ.get("TEST_DATAFRAMEBACKEND", default="1") != "1",
reason="DataFrameBackend skipped because $TEST_DATAFRAMEBACKEND != 1.",
),
id="DataFrameBackend",
),
]

Expand All @@ -63,6 +65,7 @@
os.environ.get("TEST_DYNAMODB", default="1") != "1",
reason="DynamoDB Backend skipped because $TEST_DYNAMODB != 0 or boto3 is not installed",
),
id="DynamoDBBackend",
),
)

Expand All @@ -74,6 +77,7 @@
os.environ.get("TEST_SQLBACKEND", default="1") != "1",
reason="SQL Backend skipped because $TEST_SQLBACKEND != 1 or sqlalchemy is not installed.",
),
id="SQLBackend",
),
)
if _CAN_IMPORT_IGRAPH:
Expand All @@ -84,6 +88,7 @@
os.environ.get("TEST_IGRAPHBACKEND", default="1") != "1",
reason="IGraph Backend skipped because $TEST_IGRAPHBACKEND != 1 or igraph is not installed.",
),
id="IGraphBackend",
),
)
if _CAN_IMPORT_NETWORKIT:
Expand All @@ -94,6 +99,7 @@
os.environ.get("TEST_NETWORKIT", default="1") != "1",
reason="Networkit Backend skipped because $TEST_NETWORKIT != 1 or networkit is not installed.",
),
id="NetworkitBackend",
),
)

Expand All @@ -107,6 +113,7 @@
os.environ.get("TEST_NETWORKITBACKEND") != "1",
reason="Networkit Backend skipped because $TEST_NETWORKITBACKEND != 1.",
),
id="NetworkitBackend",
),
)

Expand All @@ -120,6 +127,7 @@
os.environ.get("TEST_IGRAPHBACKEND") != "1",
reason="Networkit Backend skipped because $TEST_IGRAPHBACKEND != 1.",
),
id="IGraphBackend",
),
)

Expand Down
8 changes: 6 additions & 2 deletions grand/dialects/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,8 +186,12 @@ def __len__(self):
def number_of_nodes(self):
    """
    Return the node count reported by the parent graph's backend.
    """
    backend = self.parent.backend
    return backend.get_node_count()

def number_of_edges(self):
return self.parent.backend.get_edge_count()
def number_of_edges(self, u=None, v=None):
    """
    Return the number of edges in the graph, or between two given nodes.

    Arguments:
        u (Hashable): Optional source node ID
        v (Hashable): Optional target node ID

    Returns:
        int: The backend's total edge count when neither u nor v is given;
            otherwise 1 if the backend reports an edge for (u, v), else 0

    """
    backend = self.parent.backend
    if u is not None or v is not None:
        # Multigraphs are not supported, so a pair of nodes is joined by at
        # most one edge: the answer is 1 when the edge exists, 0 otherwise.
        return int(bool(backend.has_edge(u, v)))
    return backend.get_edge_count()


class IGraphDialect(nx.Graph):
Expand Down

0 comments on commit 7220b98

Please sign in to comment.