Skip to content

Commit

Permalink
Clean up similarity.py and use dataclasses for storing state (networkx#5831)
Browse files Browse the repository at this point in the history

* Clean up similarity.py and use dataclasses for storing state

* use nonlocal to stop using an object to store maxcost value
  • Loading branch information
MridulS authored and Alex-Markham committed Oct 13, 2023
1 parent 8310afa commit 747e094
Showing 1 changed file with 28 additions and 40 deletions.
68 changes: 28 additions & 40 deletions networkx/algorithms/similarity.py
Expand Up @@ -16,9 +16,8 @@
import math
import time
import warnings
from functools import reduce
from dataclasses import dataclass
from itertools import product
from operator import mul

import networkx as nx

Expand Down Expand Up @@ -187,7 +186,7 @@ def graph_edit_distance(
"""
bestcost = None
for vertex_path, edge_path, cost in optimize_edit_paths(
for _, _, cost in optimize_edit_paths(
G1,
G2,
node_match,
Expand Down Expand Up @@ -503,7 +502,7 @@ def optimize_graph_edit_distance(
<10.5220/0005209202710278>. <hal-01168816>
https://hal.archives-ouvertes.fr/hal-01168816
"""
for vertex_path, edge_path, cost in optimize_edit_paths(
for _, _, cost in optimize_edit_paths(
G1,
G2,
node_match,
Expand Down Expand Up @@ -672,18 +671,12 @@ def optimize_edit_paths(
import scipy as sp
import scipy.optimize # call as sp.optimize

@dataclass
class CostMatrix:
def __init__(self, C, lsa_row_ind, lsa_col_ind, ls):
# assert C.shape[0] == len(lsa_row_ind)
# assert C.shape[1] == len(lsa_col_ind)
# assert len(lsa_row_ind) == len(lsa_col_ind)
# assert set(lsa_row_ind) == set(range(len(lsa_row_ind)))
# assert set(lsa_col_ind) == set(range(len(lsa_col_ind)))
# assert ls == C[lsa_row_ind, lsa_col_ind].sum()
self.C = C
self.lsa_row_ind = lsa_row_ind
self.lsa_col_ind = lsa_col_ind
self.ls = ls
C: ...
lsa_row_ind: ...
lsa_col_ind: ...
ls: ...

def make_CostMatrix(C, m, n):
# assert(C.shape == (m + n, m + n))
Expand Down Expand Up @@ -724,7 +717,7 @@ def reduce_ind(ind, i):
rind[rind >= k] -= 1
return rind

def match_edges(u, v, pending_g, pending_h, Ce, matched_uv=[]):
def match_edges(u, v, pending_g, pending_h, Ce, matched_uv=None):
"""
Parameters:
u, v: matched vertices, u=None or v=None for
Expand All @@ -748,7 +741,10 @@ def match_edges(u, v, pending_g, pending_h, Ce, matched_uv=[]):
# only attempt to match edges after one node match has been made
# this will stop self-edges on the first node being automatically deleted
# even when a substitution is the better option
if matched_uv:
if matched_uv is None or len(matched_uv) == 0:
g_ind = []
h_ind = []
else:
g_ind = [
i
for i in range(M)
Expand All @@ -765,9 +761,6 @@ def match_edges(u, v, pending_g, pending_h, Ce, matched_uv=[]):
pending_h[j][:2] in ((q, v), (v, q), (q, q)) for p, q in matched_uv
)
]
else:
g_ind = []
h_ind = []

m = len(g_ind)
n = len(h_ind)
Expand All @@ -778,9 +771,9 @@ def match_edges(u, v, pending_g, pending_h, Ce, matched_uv=[]):

# Forbid structurally invalid matches
# NOTE: inf remembered from Ce construction
for k, i in zip(range(m), g_ind):
for k, i in enumerate(g_ind):
g = pending_g[i][:2]
for l, j in zip(range(n), h_ind):
for l, j in enumerate(h_ind):
h = pending_h[j][:2]
if nx.is_directed(G1) or nx.is_directed(G2):
if any(
Expand Down Expand Up @@ -822,8 +815,7 @@ def reduce_Ce(Ce, ij, m, n):
m_i = m - sum(1 for t in i if t < m)
n_j = n - sum(1 for t in j if t < n)
return make_CostMatrix(reduce_C(Ce.C, i, j, m, n), m_i, n_j)
else:
return Ce
return Ce

def get_edit_ops(
matched_uv, pending_u, pending_v, Cv, pending_g, pending_h, Ce, matched_cost
Expand Down Expand Up @@ -982,8 +974,9 @@ def get_edit_paths(
# assert not len(pending_g)
# assert not len(pending_h)
# path completed!
# assert matched_cost <= maxcost.value
maxcost.value = min(maxcost.value, matched_cost)
# assert matched_cost <= maxcost_value
nonlocal maxcost_value
maxcost_value = min(maxcost_value, matched_cost)
yield matched_uv, matched_gh, matched_cost

else:
Expand Down Expand Up @@ -1051,7 +1044,7 @@ def get_edit_paths(
for y, h in zip(sortedy, reversed(H)):
if h is not None:
pending_h.insert(y, h)
for t in xy:
for _ in xy:
matched_gh.pop()

# Initialization
Expand Down Expand Up @@ -1167,13 +1160,7 @@ def get_edit_paths(
# debug_print(Ce.C)
# debug_print()

class MaxCost:
def __init__(self):
# initial upper-bound estimate
# NOTE: should work for empty graph
self.value = Cv.C.sum() + Ce.C.sum() + 1

maxcost = MaxCost()
maxcost_value = Cv.C.sum() + Ce.C.sum() + 1

if timeout is not None:
if timeout <= 0:
Expand All @@ -1187,10 +1174,11 @@ def prune(cost):
if upper_bound is not None:
if cost > upper_bound:
return True
if cost > maxcost.value:
if cost > maxcost_value:
return True
elif strictly_decreasing and cost >= maxcost.value:
if strictly_decreasing and cost >= maxcost_value:
return True
return False

# Now go!

Expand All @@ -1204,7 +1192,7 @@ def prune(cost):
# assert sorted(G1.edges) == sorted(g for g, h in edge_path if g is not None)
# assert sorted(G2.edges) == sorted(h for g, h in edge_path if h is not None)
# print(vertex_path, edge_path, cost, file = sys.stderr)
# assert cost == maxcost.value
# assert cost == maxcost_value
yield list(vertex_path), list(edge_path), cost


Expand Down Expand Up @@ -1324,9 +1312,9 @@ def simrank(G, u, v):

if isinstance(x, np.ndarray):
if x.ndim == 1:
return {node: val for node, val in zip(G, x)}
else: # x.ndim == 2:
return {u: dict(zip(G, row)) for u, row in zip(G, x)}
return dict(zip(G, x))
# else x.ndim == 2
return {u: dict(zip(G, row)) for u, row in zip(G, x)}
return x


Expand Down

0 comments on commit 747e094

Please sign in to comment.